From e81c5e677df00b32106efce4fa39666651007406 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 3 Aug 2023 10:39:02 +0200
Subject: [PATCH] Implement token hashing in communication server components.

---
 declearn/communication3/server/_files.py | 22 ++++++++++++++++------
 declearn/communication3/server/_peers.py | 22 +++++++++++++++++-----
 declearn/communication3/server/_utils.py | 18 +++++++++++++++++-
 3 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/declearn/communication3/server/_files.py b/declearn/communication3/server/_files.py
index 309f1d40..3b62896e 100644
--- a/declearn/communication3/server/_files.py
+++ b/declearn/communication3/server/_files.py
@@ -22,7 +22,11 @@ import os
 import secrets
 from typing import Dict, Iterator, Optional, Tuple, Union
 
-from declearn.communication3.server._utils import create_secret_token
+from declearn.communication3.server._utils import (
+    create_secret_token,
+    hash_secret_token,
+    verify_token,
+)
 
 __all__ = [
     "FilesManager",
@@ -67,8 +71,10 @@ class SavedFile:
             Path to the file being watched and managed.
         access_token:
             Secret token required to open the file and read its contents.
+            Only a hash of this token will be stored after instantiation.
         delete_token:
             Secret token required to perform non-automated file deletion.
+            Only a hash of this token will be stored after instantiation.
         max_accesses:
             Optional maximum number of accesses, beyond which to disable
             any new file access, and enable automatic deletion.
@@ -78,8 +84,8 @@ class SavedFile:
         """
         # arguments serve modularity; pylint: disable=too-many-arguments
         self._path = path
-        self._access_token = access_token
-        self._delete_token = delete_token
+        self._access_token = hash_secret_token(access_token)
+        self._delete_token = hash_secret_token(delete_token)
         self._max_accesses = max_accesses
         self._num_accesses = 0
         self._srt_datetime = datetime.datetime.now()
@@ -138,7 +144,7 @@ class SavedFile:
         # Verify that the access is authorized.
         if self.expired:
             raise FileExpiredError("Cannot access expired file.")
-        if not secrets.compare_digest(token, self._access_token):
+        if not verify_token(token, self._access_token):
             raise FileAuthError("Incorrect access token.")
         # Increment counters on total number of accesses and current opens.
         self._num_accesses += 1
@@ -175,8 +181,12 @@ class SavedFile:
           In other words, do not let the caller know that the file has
           already been deleted.
         """
-        if not secrets.compare_digest(token, self._delete_token):
+        if not verify_token(token, self._delete_token):
             raise FileAuthError("Incorrect deletion token.")
+        self._delete()
+
+    def _delete(self):
+        """Backend method to delete this file from disk."""
         if self._n_open_reads:
             self._delete_next = True
         elif os.path.isfile(self._path):
@@ -203,7 +213,7 @@ class SavedFile:
     ) -> None:
         """Automatically delete this file from disk if it has expired."""
         if self.expired:
-            self.delete(self._delete_token)
+            self._delete()
 
 
 class FilesManager:
diff --git a/declearn/communication3/server/_peers.py b/declearn/communication3/server/_peers.py
index f2054de6..d2fdfa88 100644
--- a/declearn/communication3/server/_peers.py
+++ b/declearn/communication3/server/_peers.py
@@ -19,10 +19,13 @@
 
 import copy
 import dataclasses
-import secrets
 from typing import Dict, Optional, Set, Tuple
 
-from declearn.communication3.server._utils import create_secret_token
+from declearn.communication3.server._utils import (
+    create_secret_token,
+    hash_secret_token,
+    verify_token,
+)
 
 __all__ = [
     "PeersManager",
@@ -76,6 +79,14 @@ class Peer:
             return peer not in self.policy.blocked
         return peer in self.policy.authorized
 
+    def hash_secret_token(
+        self,
+    ) -> str:
+        """Hash this peer's token in place, and pop its initial value."""
+        token = self.token
+        self.token = hash_secret_token(token)
+        return token
+
 
 class PeersManager:
     """Component to manage a network of peers for a communication server."""
@@ -97,7 +108,8 @@ class PeersManager:
             )
             name = f"{name}.{count}"
         self._peers[name] = peer = Peer(name, policy=policy or PeerPolicy())
-        return name, peer.token
+        token = peer.hash_secret_token()
+        return name, token
 
     def authenticate_peer(
         self,
@@ -106,9 +118,9 @@ class PeersManager:
     ) -> bool:
         """Authenticate a peer based on their secret token."""
         if name in self._peers:
-            return secrets.compare_digest(token, self._peers[name].token)
+            return verify_token(token, self._peers[name].token)
         # Perform a comparison, to avoid peer name leakage via a time attack.
-        secrets.compare_digest(token, MOCK_TOKEN)
+        verify_token(token, MOCK_TOKEN)
         return False
 
     def _get_peer(
diff --git a/declearn/communication3/server/_utils.py b/declearn/communication3/server/_utils.py
index 8f58750d..7d8e12b6 100644
--- a/declearn/communication3/server/_utils.py
+++ b/declearn/communication3/server/_utils.py
@@ -17,10 +17,14 @@
 
 """Shared utils for network communication components."""
 
+import hashlib
 import secrets
 
+
 __all__ = [
     "create_secret_token",
+    "hash_secret_token",
+    "verify_token",
 ]
 
 
@@ -28,5 +32,17 @@ DEFAULT_TOKEN_NBYTES = 64  # Default number of bytes in user auth tokens.
 
 
 def create_secret_token() -> str:
-    """Create a random secret string token."""
+    """Create a random secret token."""
     return secrets.token_hex(nbytes=DEFAULT_TOKEN_NBYTES)
+
+
+def hash_secret_token(token: str) -> str:
+    """Return the hash value of an input string token."""
+    hashed = hashlib.sha512(token.encode("utf-8"), usedforsecurity=True)
+    return hashed.hexdigest()
+
+
+def verify_token(token: str, truth: str) -> bool:
+    """Verify if an input token matches a stored hashed one."""
+    token = hash_secret_token(token)
+    return secrets.compare_digest(token, truth)
-- 
GitLab