From edd62dcdf5b8e48e49d8252a623c9b99037503e3 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 5 Jul 2024 16:16:35 +0200 Subject: [PATCH] Add 'test_utils' setup routines for mock network and masking secagg. --- declearn/test_utils/__init__.py | 7 +++- declearn/test_utils/_network.py | 51 ++++++++++++++++++++++++++- declearn/test_utils/_secagg.py | 61 +++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 declearn/test_utils/_secagg.py diff --git a/declearn/test_utils/__init__.py b/declearn/test_utils/__init__.py index 37915bf7..f05bbb7b 100644 --- a/declearn/test_utils/__init__.py +++ b/declearn/test_utils/__init__.py @@ -38,7 +38,12 @@ from ._assertions import ( from ._convert import to_numpy from ._gen_ssl import generate_ssl_certificates from ._imports import make_importable -from ._network import MockNetworkClient, MockNetworkServer +from ._network import ( + MockNetworkClient, + MockNetworkServer, + setup_mock_network_endpoints, +) +from ._secagg import build_secagg_controllers from ._vectors import ( FrameworkType, GradientsTestCase, diff --git a/declearn/test_utils/_network.py b/declearn/test_utils/_network.py index db051532..e0ad213c 100644 --- a/declearn/test_utils/_network.py +++ b/declearn/test_utils/_network.py @@ -18,9 +18,13 @@ """Fake network communication endpoints relying on shared memory objects.""" import asyncio +import contextlib import logging import uuid -from typing import Dict, Mapping, Optional, Set, TypeVar, Union +from typing import ( + # fmt: off + AsyncIterator, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union +) from declearn.communication.api import NetworkClient, NetworkServer @@ -31,6 +35,7 @@ from declearn.messaging import Message, SerializedMessage __all__ = [ "MockNetworkClient", "MockNetworkServer", + "setup_mock_network_endpoints", ] @@ -179,3 +184,47 @@ class MockNetworkClient(NetworkClient, register=False): ) -> SerializedMessage: # Force the use of a timeout, to prevent tests from being stuck. return await super().recv_message(timeout=timeout or 5) + + +@contextlib.asynccontextmanager +async def setup_mock_network_endpoints( + n_peers: int, + port: int = 8765, +) -> AsyncIterator[Tuple[MockNetworkServer, List[MockNetworkClient]]]: + """Instantiate, start and register mock network communication endpoints. + + This is an async context manager, that returns network endpoints, + and ensures they are all properly closed upon leaving the context. + + Parameters + ---------- + n_peers: + Number of client endpoints to instantiate. + port: + Mock port number to use. + + Returns + ------- + server: + `MockNetworkServer` instance to which clients are registered. + clients: + List of `MockNetworkClient` instances, registered to the server. + """ + # Instantiate the endpoints. + server = MockNetworkServer(port=port) + clients = [ + MockNetworkClient(f"mock://localhost:{port}", name=f"client_{i}") + for i in range(n_peers) + ] + async with contextlib.AsyncExitStack() as stack: + # Start the endpoints and ensure they will be properly closed. + await stack.enter_async_context(server) # type: ignore + for client in clients: + await stack.enter_async_context(client) # type: ignore + # Register the clients with the server. + await asyncio.gather( + server.wait_for_clients(n_peers), + *[client.register() for client in clients], + ) + # Yield the started, registered endpoints. + yield server, clients diff --git a/declearn/test_utils/_secagg.py b/declearn/test_utils/_secagg.py new file mode 100644 index 00000000..73c6620f --- /dev/null +++ b/declearn/test_utils/_secagg.py @@ -0,0 +1,61 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Routine to set up some SecAgg controllers.""" + +import secrets +from typing import List, Tuple + + +from declearn.secagg.masking import MaskingDecrypter, MaskingEncrypter + +__all__ = [ + "build_secagg_controllers", +] + + +def build_secagg_controllers( + n_peers: int, +) -> Tuple[MaskingDecrypter, List[MaskingEncrypter]]: + """Setup aligned masking-based encrypters and decrypter. + + Parameters + ---------- + n_peers: + Number of clients for which to set up an encrypter. + + Returns + ------- + decrypter: + `MaskingDecrypter` instance. + encrypters: + List of `MaskingEncrypter` instances with compatible seeds. + """ + n_pairs = int(n_peers * (n_peers - 1) / 2) + s_keys = [secrets.randbits(32) for _ in range(n_pairs)] + clients = [] # type: List[MaskingEncrypter] + starts = [n_peers - i - 1 for i in range(n_peers)] + starts = [sum(starts[:i]) for i in range(n_peers)] + for idx in range(n_peers): + pos = s_keys[starts[idx] : starts[idx] + n_peers - idx - 1] + neg = [s_keys[starts[j] + idx - j - 1] for j in range(idx)] + clients.append( + MaskingEncrypter(pos_masks_seeds=pos, neg_masks_seeds=neg) + ) + server = MaskingDecrypter(n_peers=n_peers) + return server, clients -- GitLab