From edd62dcdf5b8e48e49d8252a623c9b99037503e3 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 Jul 2024 16:16:35 +0200
Subject: [PATCH] Add 'test_utils' setup routines for mock network and masking
 secagg.

---
 declearn/test_utils/__init__.py |  7 +++-
 declearn/test_utils/_network.py | 51 ++++++++++++++++++++++++++-
 declearn/test_utils/_secagg.py  | 61 +++++++++++++++++++++++++++++++++
 3 files changed, 117 insertions(+), 2 deletions(-)
 create mode 100644 declearn/test_utils/_secagg.py

diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py
index 37915bf7..f05bbb7b 100644
--- a/declearn/test_utils/__init__.py
+++ b/declearn/test_utils/__init__.py
@@ -38,7 +38,12 @@ from ._assertions import (
 from ._convert import to_numpy
 from ._gen_ssl import generate_ssl_certificates
 from ._imports import make_importable
-from ._network import MockNetworkClient, MockNetworkServer
+from ._network import (
+    MockNetworkClient,
+    MockNetworkServer,
+    setup_mock_network_endpoints,
+)
+from ._secagg import build_secagg_controllers
 from ._vectors import (
     FrameworkType,
     GradientsTestCase,
diff --git a/declearn/test_utils/_network.py b/declearn/test_utils/_network.py
index db051532..e0ad213c 100644
--- a/declearn/test_utils/_network.py
+++ b/declearn/test_utils/_network.py
@@ -18,9 +18,13 @@
 """Fake network communication endpoints relying on shared memory objects."""
 
 import asyncio
+import contextlib
 import logging
 import uuid
-from typing import Dict, Mapping, Optional, Set, TypeVar, Union
+from typing import (
+    # fmt: off
+    AsyncIterator, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
+)
 
 
 from declearn.communication.api import NetworkClient, NetworkServer
@@ -31,6 +35,7 @@ from declearn.messaging import Message, SerializedMessage
 __all__ = [
     "MockNetworkClient",
     "MockNetworkServer",
+    "setup_mock_network_endpoints",
 ]
 
 
@@ -179,3 +184,47 @@ class MockNetworkClient(NetworkClient, register=False):
     ) -> SerializedMessage:
         # Force the use of a timeout, to prevent tests from being stuck.
         return await super().recv_message(timeout=timeout or 5)
+
+
+@contextlib.asynccontextmanager
+async def setup_mock_network_endpoints(
+    n_peers: int,
+    port: int = 8765,
+) -> AsyncIterator[Tuple[MockNetworkServer, List[MockNetworkClient]]]:
+    """Instantiate, start and register mock network communication endpoints.
+
+    This is an async context manager, that returns network endpoints,
+    and ensures they are all properly closed upon leaving the context.
+
+    Parameters
+    ----------
+    n_peers:
+        Number of client endpoints to instantiate.
+    port:
+        Mock port number to use.
+
+    Returns
+    -------
+    server:
+        `MockNetworkServer` instance to which clients are registered.
+    clients:
+        List of `MockNetworkClient` instances, registered to the server.
+    """
+    # Instantiate the endpoints.
+    server = MockNetworkServer(port=port)
+    clients = [
+        MockNetworkClient(f"mock://localhost:{port}", name=f"client_{i}")
+        for i in range(n_peers)
+    ]
+    async with contextlib.AsyncExitStack() as stack:
+        # Start the endpoints and ensure they will be properly closed.
+        await stack.enter_async_context(server)  # type: ignore
+        for client in clients:
+            await stack.enter_async_context(client)  # type: ignore
+        # Register the clients with the server.
+        await asyncio.gather(
+            server.wait_for_clients(n_peers),
+            *[client.register() for client in clients],
+        )
+        # Yield the started, registered endpoints.
+        yield server, clients
diff --git a/declearn/test_utils/_secagg.py b/declearn/test_utils/_secagg.py
new file mode 100644
index 00000000..73c6620f
--- /dev/null
+++ b/declearn/test_utils/_secagg.py
@@ -0,0 +1,61 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Routine to set up some SecAgg controllers."""
+
+import secrets
+from typing import List, Tuple
+
+
+from declearn.secagg.masking import MaskingDecrypter, MaskingEncrypter
+
+__all__ = [
+    "build_secagg_controllers",
+]
+
+
+def build_secagg_controllers(
+    n_peers: int,
+) -> Tuple[MaskingDecrypter, List[MaskingEncrypter]]:
+    """Setup aligned masking-based encrypters and decrypter.
+
+    Parameters
+    ----------
+    n_peers:
+        Number of clients for which to set up an encrypter.
+
+    Returns
+    -------
+    decrypter:
+        `MaskingDecrypter` instance.
+    encrypters:
+        List of `MaskingEncrypter` instances with compatible seeds.
+    """
+    n_pairs = int(n_peers * (n_peers - 1) / 2)
+    s_keys = [secrets.randbits(32) for _ in range(n_pairs)]
+    clients = []  # type: List[MaskingEncrypter]
+    starts = [n_peers - i - 1 for i in range(n_peers)]
+    starts = [sum(starts[:i]) for i in range(n_peers)]
+    for idx in range(n_peers):
+        pos = s_keys[starts[idx] : starts[idx] + n_peers - idx - 1]
+        neg = [s_keys[starts[j] + idx - j - 1] for j in range(idx)]
+        clients.append(
+            MaskingEncrypter(pos_masks_seeds=pos, neg_masks_seeds=neg)
+        )
+    server = MaskingDecrypter(n_peers=n_peers)
+    return server, clients
-- 
GitLab