Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit edd62dcd authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Add 'test_utils' setup routines for mock network and masking secagg.

parent 4cd7d94d
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
......@@ -38,7 +38,12 @@ from ._assertions import (
from ._convert import to_numpy
from ._gen_ssl import generate_ssl_certificates
from ._imports import make_importable
from ._network import MockNetworkClient, MockNetworkServer
from ._network import (
MockNetworkClient,
MockNetworkServer,
setup_mock_network_endpoints,
)
from ._secagg import build_secagg_controllers
from ._vectors import (
FrameworkType,
GradientsTestCase,
......
......@@ -18,9 +18,13 @@
"""Fake network communication endpoints relying on shared memory objects."""
import asyncio
import contextlib
import logging
import uuid
from typing import Dict, Mapping, Optional, Set, TypeVar, Union
from typing import (
# fmt: off
AsyncIterator, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
)
from declearn.communication.api import NetworkClient, NetworkServer
......@@ -31,6 +35,7 @@ from declearn.messaging import Message, SerializedMessage
__all__ = [
"MockNetworkClient",
"MockNetworkServer",
"setup_mock_network_endpoints",
]
......@@ -179,3 +184,47 @@ class MockNetworkClient(NetworkClient, register=False):
) -> SerializedMessage:
# Force the use of a timeout, to prevent tests from being stuck.
return await super().recv_message(timeout=timeout or 5)
@contextlib.asynccontextmanager
async def setup_mock_network_endpoints(
n_peers: int,
port: int = 8765,
) -> AsyncIterator[Tuple[MockNetworkServer, List[MockNetworkClient]]]:
"""Instantiate, start and register mock network communication endpoints.
This is an async context manager, that returns network endpoints,
and ensures they are all properly closed upon leaving the context.
Parameters
----------
n_peers:
Number of client endpoints to instantiate.
port:
Mock port number to use.
Returns
-------
server:
`MockNetworkServer` instance to which clients are registered.
clients:
List of `MockNetworkClient` instances, registered to the server.
"""
# Instantiate the endpoints.
server = MockNetworkServer(port=port)
clients = [
MockNetworkClient(f"mock://localhost:{port}", name=f"client_{i}")
for i in range(n_peers)
]
async with contextlib.AsyncExitStack() as stack:
# Start the endpoints and ensure they will be properly closed.
await stack.enter_async_context(server) # type: ignore
for client in clients:
await stack.enter_async_context(client) # type: ignore
# Register the clients with the server.
await asyncio.gather(
server.wait_for_clients(n_peers),
*[client.register() for client in clients],
)
# Yield the started, registered endpoints.
yield server, clients
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Routine to set up some SecAgg controllers."""
import secrets
from typing import List, Tuple
from declearn.secagg.masking import MaskingDecrypter, MaskingEncrypter
__all__ = [
"build_secagg_controllers",
]
def build_secagg_controllers(
n_peers: int,
) -> Tuple[MaskingDecrypter, List[MaskingEncrypter]]:
"""Setup aligned masking-based encrypters and decrypter.
Parameters
----------
n_peers:
Number of clients for which to set up an encrypter.
Returns
-------
decrypter:
`MaskingDecrypter` instance.
encrypters:
List of `MaskingEncrypter` instances with compatible seeds.
"""
n_pairs = int(n_peers * (n_peers - 1) / 2)
s_keys = [secrets.randbits(32) for _ in range(n_pairs)]
clients = [] # type: List[MaskingEncrypter]
starts = [n_peers - i - 1 for i in range(n_peers)]
starts = [sum(starts[:i]) for i in range(n_peers)]
for idx in range(n_peers):
pos = s_keys[starts[idx] : starts[idx] + n_peers - idx - 1]
neg = [s_keys[starts[j] + idx - j - 1] for j in range(idx)]
clients.append(
MaskingEncrypter(pos_masks_seeds=pos, neg_masks_seeds=neg)
)
server = MaskingDecrypter(n_peers=n_peers)
return server, clients
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment