diff --git a/declearn/fairness/api/__init__.py b/declearn/fairness/api/__init__.py index 832b9e11a0819b54fb22aadec8d072a6d9598bdf..4c76846005a192d8075b858c2b3a67cfd75e6eb3 100644 --- a/declearn/fairness/api/__init__.py +++ b/declearn/fairness/api/__init__.py @@ -17,15 +17,5 @@ """Draft API for Fairness-aware Federated Learning algorithms.""" -from ._messages import ( - FairnessAccuracy, - FairnessCounts, - FairnessGroups, - SecaggFairnessAccuracy, - SecaggFairnessCounts, -) -from ._controllers import ( - FairnessControllerClient, - FairnessControllerServer, - FairnessSetupQuery, -) +from ._client import FairnessControllerClient +from ._server import FairnessControllerServer diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd98ac5e50372f1f9d603c0185224701f5dc1bb --- /dev/null +++ b/declearn/fairness/api/_client.py @@ -0,0 +1,334 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client-side ABC for fairness-aware federated learning controllers.""" + +import abc +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union + +import numpy as np + +from declearn.communication.api import NetworkClient +from declearn.communication.utils import verify_server_message_validity +from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset +from declearn.messaging import ( + Error, + FairnessCounts, + FairnessGroups, + FairnessQuery, + FairnessReply, + FairnessSetupQuery, +) +from declearn.secagg.api import Encrypter +from declearn.secagg.messaging import ( + SecaggFairnessCounts, + SecaggFairnessReply, +) +from declearn.training import TrainingManager +from declearn.utils import ( + access_registered, + create_types_registry, + register_type, +) + +__all__ = [ + "FairnessControllerClient", +] + + +@create_types_registry(name="FairnessControllerClient") +class FairnessControllerClient(metaclass=abc.ABCMeta): + """Abstract base class for client-side fairness controllers.""" + + algorithm: ClassVar[str] + """Name of the fairness-enforcing algorithm. + + This name should be unique across 'FairnessControllerClient' classes, + and shared with a unique paired 'FairnessControllerServer'. It is used + for type-registration and to enable instantiating a client controller + based on server-emitted instructions in a federated setting. + """ + + def __init_subclass__( + cls, + register: bool = True, + ) -> None: + """Automatically type-register subclasses.""" + if register: + register_type(cls, cls.algorithm, group="FairnessControllerClient") + + def __init__( + self, + manager: TrainingManager, + ) -> None: + """Instantiate the client-side fairness controller. + + Parameters + ---------- + manager: + `TrainingManager` instance wrapping the model being trained + and its training dataset (that must be a `FairnessDataset`). + """ + if not isinstance(manager.train_data, FairnessDataset): + raise TypeError( + "Cannot set up fairness without a 'FairnessDataset' " + "as training dataset." + ) + self.manager = manager + self.computer = FairnessAccuracyComputer(manager.train_data) + self.groups = [] # type: List[Tuple[Any, ...]] + + @staticmethod + def from_setup_query( + query: FairnessSetupQuery, + manager: TrainingManager, + ) -> "FairnessControllerClient": + """Instantiate a controller from a server-emitted query. + + Parameters + ---------- + query: + `FairnessSetupQuery` received from the server. + manager: + `TrainingManager` wrapping the model to train. + + Returns + ------- + controller: + `FairnessControllerClient` instance, the type and parameters + of which depend on the input `query`, that wraps `manager`. + """ + try: + cls = access_registered( + name=query.algorithm, group="FairnessControllerClient" + ) + assert issubclass(cls, FairnessControllerClient) + except Exception as exc: + raise ValueError( + "Failed to retrieve a 'FairnessControllerClient' class " + "matching the input 'FairnessSetupQuery' message." + ) from exc + return cls(manager=manager, **query.params) + + async def setup_fairness( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + """Participate in a routine to initialize fairness-aware learning. + + This routine has the following structure: + + - Exchange with the server to agree on an ordered list of sensitive + groups defined by the interesection of 1+ sensitive attributes + and (opt.) a classification target label. + - Send (encrypted) group-wise training sample counts, that the server + is to (secure-)aggregate. + - Perform any additional actions specific to the algorithm in use. + - On the client side, optionally alter the `TrainingManager` used. + - On the server side, optionally alter the `Aggregator` used. + + Parameters + ---------- + netwk: + NetworkClient endpoint, registered to a server. + secagg: + Optional SecAgg encryption controller. + """ + # Share sensitive groups definitions and received an ordered list. + self.groups = await self._exchange_sensitive_groups_list(netwk) + # Send group-wise sample counts for the server to (secure-)aggregate. + await self._send_sensitive_groups_counts(netwk, secagg) + # Run additional algorithm-specific setup steps. + await self.finalize_fairness_setup(netwk, secagg) + + async def _exchange_sensitive_groups_list( + self, + netwk: NetworkClient, + ) -> List[Tuple[Any, ...]]: + """Exhange sensitive groups definitions and return a unified list.""" + # Gather local sensitive groups and their sample counts. + counts = self.computer.counts + groups = list(counts) + # Share them and receive a unified, ordered list of groups. + await netwk.send_message(FairnessGroups(groups=groups)) + received = await netwk.recv_message() + message = await verify_server_message_validity( + netwk, received, expected=FairnessGroups + ) + return message.groups + + async def _send_sensitive_groups_counts( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + """Send (opt. encrypted) group-wise sample counts to the server.""" + counts = self.computer.counts + reply = FairnessCounts([counts.get(group, 0) for group in self.groups]) + if secagg is None: + await netwk.send_message(reply) + else: + await netwk.send_message( + SecaggFairnessCounts.from_cleartext_message(reply, secagg) + ) + + @abc.abstractmethod + async def finalize_fairness_setup( + self, + netwk: NetworkClient, + secagg: Optional[Encrypter], + ) -> None: + """Finalize the fairness setup routine. + + This method is called as part of `setup_fairness`, and should + be defined by concrete subclasses to implement setup behavior + once the initial echange of sensitive group definitions and + sample counts has been performed. + + Parameters + ---------- + netwk: + NetworkClient endpoint, registered to a server. + secagg: + Optional SecAgg encryption controller. + """ + + async def fairness_round( + self, + netwk: NetworkClient, + query: FairnessQuery, + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Participate in a round of actions to enforce fairness. + + Parameters + ---------- + netwk: + NetworkClient endpoint instance, connected to a server. + query: + `FairnessQuery` message to participate in a fairness round. + secagg: + Optional SecAgg encryption controller. + + Returns + ------- + metrics: + Computed local fairness(-related) metrics computed as part + of this routine, as a dict mapping scalar or numpy array + values with their name. + """ + try: + values = await self._compute_and_share_fairness_measures( + netwk, query, secagg + ) + except Exception as exc: + error = f"Error encountered in fairness round: {repr(exc)}" + self.manager.logger.error(error) + await netwk.send_message(Error(error)) + raise RuntimeError(error) from exc + # Run additional algorithm-specific steps. + return await self.finalize_fairness_round(netwk, values, secagg) + + async def _compute_and_share_fairness_measures( + self, + netwk: NetworkClient, + query: FairnessQuery, + secagg: Optional[Encrypter], + ) -> List[float]: + """Compute, share (encrypted) and return fairness measures.""" + # Optionally update the wrapped model's weights. + if query.weights is not None: + self.manager.model.set_weights(query.weights, trainable=True) + # Compute, opt. encrypt and share fairness-related metrics. + values = self.compute_fairness_measures( + query.batch_size, query.n_batch, query.thresh + ) + reply = FairnessReply(values=values) + if secagg is None: + await netwk.send_message(reply) + else: + await netwk.send_message( + SecaggFairnessReply.from_cleartext_message(reply, secagg) + ) + # Return computed values. + return values + + @abc.abstractmethod + def compute_fairness_measures( + self, + batch_size: int, + n_batch: Optional[int] = None, + thresh: Optional[float] = None, + ) -> List[float]: + """Compute fairness measures based on a received query. + + By default, compute and return group-wise accuracy metrics, + weighted by group-wise sample counts. This may be modified + by algorithm-specific subclasses depending on algorithms' + needs. + + Parameters + ---------- + batch_size: + Number of samples per batch when computing predictions. + n_batch: + Optional maximum number of batches to draw per category. + If None, use the entire wrapped dataset. + thresh: + Optional binarization threshold for binary classification + models' output scores. If None, use 0.5 by default, or 0.0 + for `SklearnSGDModel` instances. + Unused for multinomial classifiers (argmax over scores). + + Returns + ------- + values: + Computed values, as a deterministic-length ordered list + of float values. + """ + + @abc.abstractmethod + async def finalize_fairness_round( + self, + netwk: NetworkClient, + values: List[float], + secagg: Optional[Encrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Take actions to enforce fairness. + + This method is designed to be called after an initial query + has been received and responded to, resulting in computing + and sharing fairness(-related) metrics. + + Parameters + ---------- + netwk: + NetworkClient endpoint instance, connected to a server. + values: + List of locally-computed evaluation metrics, already shared + with the server for their (secure-)aggregation. + secagg: + Optional SecAgg encryption controller. + + Returns + ------- + metrics: + Computed local fairness(-related) metrics computed as part + of this routine, as a dict mapping scalar or numpy array + values with their name. + """ diff --git a/declearn/fairness/api/_controllers.py b/declearn/fairness/api/_controllers.py deleted file mode 100644 index 52277c0a9303c3b1d05dbebfb494a1d67b1b1b91..0000000000000000000000000000000000000000 --- a/declearn/fairness/api/_controllers.py +++ /dev/null @@ -1,478 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Draft API for Fairness-aware Federated Learning.""" - -import abc -import dataclasses -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -from declearn.aggregator import Aggregator -from declearn.communication.api import NetworkClient, NetworkServer -from declearn.communication.utils import ( - verify_client_messages_validity, - verify_server_message_validity, -) -from declearn.fairness.api._messages import ( - FairnessCounts, - FairnessGroups, - SecaggFairnessCounts, -) -from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset -from declearn.messaging import Error, FairnessQuery, FairnessReply, Message -from declearn.secagg.api import Decrypter, Encrypter -from declearn.secagg.messaging import ( - aggregate_secagg_messages, - SecaggFairnessReply, -) -from declearn.training import TrainingManager - -__all__ = [ - "FairnessControllerClient", - "FairnessControllerServer", - "FairnessSetupQuery", -] - - -class FairnessControllerClient(metaclass=abc.ABCMeta): - """Abstract base class for client-side fairness controllers.""" - - def __init__( - self, - ) -> None: - """Instantiate the client-side fairness controller.""" - self.groups = [] # type: List[Tuple[Any, ...]] - - async def setup_fairness( - self, - netwk: NetworkClient, - manager: TrainingManager, - secagg: Optional[Encrypter], - params: Dict[str, Any], - ) -> TrainingManager: - """Participate in a routine to initialize fairness-aware learning. - - This routine has the following structure: - - - Exchange with the server to agree on an ordered list of sensitive - groups defined by the interesection of 1+ sensitive attributes - and (opt.) a classification target label. - - Send (encrypted) group-wise training sample counts, that the server - is to (secure-)aggregate. - - Perform any additional actions specific to the algorithm in use. - - On the client side, optionally alter the `TrainingManager` used. - - On the server side, optionally alter the `Aggregator` used. - - Parameters - ---------- - netwk: - NetworkClient endpoint, registered to a server. - manager: - TrainingManager instance that was set up notwithstanding fairness. - secagg: - Optional SecAgg encryption controller. - params: - Dict of algorithm-specific keyword arguments received from - the server as part of the query that triggered this routine. - - Warns - ----- - RuntimeWarning - If the returned training manager differs from the input one. - - Returns - ------- - manager: - `TrainingManager` instance to use in the FL process, that may - or may not have been altered compared with the input one. - """ - # Verify that a training 'FairnessDataset' is available. - if not isinstance(manager.train_data, FairnessDataset): - msg = "Cannot set up fairness without a 'FairnessDataset'." - await netwk.send_message(Error(msg)) - raise TypeError(msg) - # Gather local sensitive groups and their sample counts. - counts = manager.train_data.get_sensitive_group_counts() - groups = list(counts) - # Share them and receive a unified, ordered list of groups. - await netwk.send_message(FairnessGroups(groups=groups)) - received = await netwk.recv_message() - message = await verify_server_message_validity( - netwk, received, expected=FairnessGroups - ) - self.groups = message.groups - # Sort and fill out sample counts, opt. encrypt them and send them. - reply = FairnessCounts([counts.get(group, 0) for group in self.groups]) - if secagg is None: - await netwk.send_message(reply) - else: - await netwk.send_message( - SecaggFairnessCounts.from_cleartext_message(reply, secagg) - ) - # Run additional algorithm-specific setup steps. - return await self.finalize_fairness_setup( - netwk, manager, secagg, params - ) - - @abc.abstractmethod - async def finalize_fairness_setup( - self, - netwk: NetworkClient, - manager: TrainingManager, - secagg: Optional[Encrypter], - params: Dict[str, Any], - ) -> TrainingManager: - """Finalize the fairness setup routine and return an Aggregator. - - This method is called as part of `setup_fairness`, and should - be defined by concrete subclasses to implement setup behavior - once the initial query/reply messages have been exchanged. - - The returned `TrainingManager` may either be the input `manager` - or a new or modified version of it, depending on the needs of - the fairness-aware federated learning process being implemented. - - Parameters - ---------- - netwk: - NetworkClient endpoint, registered to a server. - manager: - TrainingManager instance that was set up notwithstanding fairness. - secagg: - Optional SecAgg encryption controller. - params: - Dict of algorithm-specific keyword arguments received from - the server as part of the query that triggered this routine. - - Warns - ----- - RuntimeWarning - If the returned training manager differs from the input one. - - Returns - ------- - manager: - `TrainingManager` instance to use in the FL process, that may - or may not have been altered compared with the input one. - """ - - async def fairness_round( - self, - netwk: NetworkClient, - query: FairnessQuery, - manager: TrainingManager, - secagg: Optional[Encrypter], - ) -> None: - """Participate in a round of actions to enforce fairness. - - Parameters - ---------- - netwk: - NetworkClient endpoint instance, connected to a server. - query: - `FairnessQuery` message to participate in a fairness round. - manager: - TrainingManager instance holding the local model, optimizer, etc. - This method may (and usually does) have side effects on this. - secagg: - Optional SecAgg encryption controller. - """ - values = self.compute_fairness_measures(query, manager) - reply = FairnessReply(values=values) - if secagg is None: - await netwk.send_message(reply) - else: - await netwk.send_message( - SecaggFairnessReply.from_cleartext_message(reply, secagg) - ) - await self.finalize_fairness_round(netwk, values, manager, secagg) - - def compute_fairness_measures( - self, - query: FairnessQuery, - manager: TrainingManager, - ) -> List[float]: - """Compute fairness measures based on a received query. - - By default, compute and return group-wise accuracy metrics, - weighted by group-wise sample counts. This may be modified - by algorithm-specific subclasses depending on algorithms' - needs. - - Parameters - ---------- - query: - `FairnessQuery` message with computational effort constraints, - and optionally model weights to assign before evaluation. - manager: - TrainingManager instance holding the model to evaluate and the - training dataset on which to do so. - - Returns - ------- - values: - Computed values, as a deterministic-length ordered list - of float values. - """ - assert isinstance(manager.train_data, FairnessDataset) - if query.weights is not None: - manager.model.set_weights(query.weights, trainable=True) - # Compute group-wise accuracy metrics. - computer = FairnessAccuracyComputer(manager.train_data) - accuracy = computer.compute_groupwise_accuracy( - model=manager.model, - batch_size=query.batch_size, - n_batch=query.n_batch, - thresh=query.thresh, - ) - # Scale computed accuracy metrics by sample counts. - accuracy = { - key: val * computer.counts[key] for key, val in accuracy.items() - } - # Gather ordered values (filling-in groups without samples). - return [accuracy.get(group, 0.0) for group in self.groups] - - @abc.abstractmethod - async def finalize_fairness_round( - self, - netwk: NetworkClient, - values: List[float], - manager: TrainingManager, - secagg: Optional[Encrypter], - ) -> None: - """Take actions to enforce fairness. - - This method is designed to be called after an initial query - has been received and responded to, resulting in computing - and sharing fairness(-related) metrics. - - Parameters - ---------- - netwk: - NetworkClient endpoint instance, connected to a server. - values: - List of locally-computed evaluation metrics, already shared - with the server for their (secure-)aggregation. - manager: - TrainingManager instance holding the local model, optimizer, etc. - This method may (and usually does) have side effects on this. - secagg: - Optional SecAgg encryption controller. - """ - - -@dataclasses.dataclass -class FairnessSetupQuery(Message, register=False, metaclass=abc.ABCMeta): - """ABC message for all Fairness setup init requests. - - This message should be subclassed into algorithm-specific messages. - """ - - @abc.abstractmethod - def instantiate_controller( - self, - ) -> FairnessControllerClient: - """Instantiate a `FairnessControllerClient` matching this query.""" - - def get_setup_params( - self, - ) -> Dict[str, Any]: - """Return a dict of parameters to pass to the client setup routine.""" - return {} - - -class FairnessControllerServer(metaclass=abc.ABCMeta): - """Abstract base class for server-side fairness controllers.""" - - def __init__( - self, - f_type: str, - f_args: Optional[Dict[str, Any]], - ) -> None: - """Instantiate the server-side fairness controller. - - Parameters - ---------- - f_type: - Name of the fairness function to evaluate and optimize. - f_args: - Optional dict of keyword arguments to the fairness function. - """ - self.f_type = f_type - self.f_args = f_args or {} - self.groups = [] # type: List[Tuple[Any, ...]] - - async def setup_fairness( - self, - netwk: NetworkServer, - aggregator: Aggregator, - secagg: Optional[Decrypter], - ) -> Aggregator: - """Orchestrate a routine to initialize fairness-aware learning. - - This routine has the following structure: - - - Send a setup query to clients, the type of which depends - on the actual fairness-enforcing algorithm used. - - Exchange with clients to agree on an ordered list of sensitive - groups defined by the interesection of 1+ sensitive attributes - and (opt.) a classification target label. - - Receive and (secure-)aggregate group-wise sample counts across - clients' training dataset. - - Perform any additional actions specific to the algorithm in use. - - On the server side, optionally alter the `Aggregator` used. - - On the client side, optionally alter the `TrainingManager` used. - - Parameters - ---------- - netwk: - NetworkServer endpoint, to which clients are registered. - aggregator: - Aggregator instance that was set up notwithstanding fairness. - secagg: - Optional SecAgg decryption controller. - - Warns - ----- - RuntimeWarning - If the returned aggregator differs from the input one. - - Returns - ------- - aggregator: - `Aggregator` instance to use in the FL process, that may - or may not have been altered compared with the input one. - """ - # Send a setup query to all clients. - query = self.prepare_fairness_setup_query() - await netwk.broadcast_message(query) - # Receive, aggregate, assign and send back sensitive group definitions. - await self._exchange_sensitive_groups_list(netwk) - # Wait for group-wise sample counts from clients. - received = await netwk.wait_for_messages() - # When SecAgg is not used, expect cleartext group-wise counts. - if secagg is None: - replies = await verify_client_messages_validity( - netwk, received, expected=FairnessCounts - ) - counts = self._aggregate_cleartext_counts(replies) - # When SecAgg is used, expect and secure-aggregate encrypted counts. - else: - sec_rep = await verify_client_messages_validity( - netwk, received, expected=SecaggFairnessCounts - ) - counts = aggregate_secagg_messages(sec_rep, secagg).counts - # Run additional algorithm-specific setup steps. - return await self.finalize_fairness_setup(netwk, counts, aggregator) - - def _aggregate_cleartext_counts( - self, - messages: Dict[str, FairnessCounts], - ) -> List[int]: - """Sum group-wise sample counts received from clients.""" - counts = np.zeros(len(self.groups), dtype="uint64") - for message in messages.values(): - counts += np.asarray(message.counts, dtype="uint64") - return counts.tolist() - - async def _exchange_sensitive_groups_list( - self, - netwk: NetworkServer, - ) -> None: - """Receive, aggregate, assign and share sensitive group definitions.""" - received = await netwk.wait_for_messages() - # Verify and deserialize client-wise sensitive group definitions. - messages = await verify_client_messages_validity( - netwk, received, expected=FairnessGroups - ) - # Gather the sorted union of all existing definitions. - unique = {group for msg in messages.values() for group in msg.groups} - self.groups = sorted(list(unique)) - # Send it to clients, and expect their reply (encrypted counts). - await netwk.broadcast_message(FairnessGroups(groups=self.groups)) - - @abc.abstractmethod - def prepare_fairness_setup_query( - self, - ) -> FairnessSetupQuery: - """Return a request to setup fairness, broadcastable to clients. - - Returns - ------- - message: - `FairnessSetupQuery` subclass instance to be sent to clients - in order to trigger the Fairness setup protocol. - """ - - @abc.abstractmethod - async def finalize_fairness_setup( - self, - netwk: NetworkServer, - counts: List[int], - aggregator: Aggregator, - ) -> Aggregator: - """Finalize the fairness setup routine and return an Aggregator. - - This method is called as part of `setup_fairness`, and should - be defined by concrete subclasses to implement setup behavior - once the initial query/reply messages have been exchanged. - - The returned `Aggregator` may either be the input `aggregator` - or a new or modified version of it, depending on the needs of - the fairness-aware federated learning process being implemented. - - Warns - ----- - RuntimeWarning - If the returned aggregator differs from the input one. - - Returns - ------- - aggregator: - `Aggregator` instance to use in the FL process, that may - or may not have been altered compared with the input one. - """ - - @abc.abstractmethod - async def finalize_fairness_round( - self, - round_i: int, - values: List[float], - netwk: NetworkServer, - secagg: Optional[Decrypter], - ) -> None: - """Orchestrate a round of actions to enforce fairness. - - This method is designed to be called after an initial query - has been sent and responded to by clients, resulting in the - federated computation of fairness(-related) metrics. - - Parameters - ---------- - round_i: - Index of the current round (reflecting that of an upcoming - training round). - values: - Aggregated metrics resulting from the fairness evaluation - run by clients at this round. - netwk: - NetworkServer endpoint instance, to which clients are registered. - secagg: - Optional SecAgg decryption controller. - """ diff --git a/declearn/fairness/api/_messages.py b/declearn/fairness/api/_messages.py deleted file mode 100644 index 7beb0785ddcfb076ad9bc37009168a5da5dd66ad..0000000000000000000000000000000000000000 --- a/declearn/fairness/api/_messages.py +++ /dev/null @@ -1,168 +0,0 @@ -# coding: utf-8 - -# Copyright 2023 Inria (Institut National de Recherche en Informatique -# et Automatique) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""API messages for fairness-aware federated learning setup and rounds.""" - -import dataclasses -from typing import Any, List, Tuple - -from typing_extensions import Self # future: import from typing (py >=3.11) - -from declearn.messaging import Message -from declearn.secagg.api import Decrypter, Encrypter -from declearn.secagg.messaging import SecaggMessage - -__all__ = [ - "FairnessAccuracy", - "FairnessCounts", - "FairnessGroups", - "SecaggFairnessAccuracy", - "SecaggFairnessCounts", -] - - -@dataclasses.dataclass -class FairnessAccuracy(Message): - """Message for client-emitted model accuracy across sensitive groups. - - Fields - ------ - values: - List of group-wise accuracy values, ordered based - on an agreed-upon sorted list of sensitive groups. - """ - - values: List[float] - - typekey = "fairness-accuracy" - - -@dataclasses.dataclass -class SecaggFairnessAccuracy(SecaggMessage[FairnessAccuracy]): - """SecAgg counterpart of the 'FairnessAccuracy' message class.""" - - values: List[int] - - typekey = "secagg-fairness-accuracy" - - @classmethod - def from_cleartext_message( - cls, - cleartext: FairnessAccuracy, - encrypter: Encrypter, - ) -> Self: - values = [encrypter.encrypt_float(val) for val in cleartext.values] - return cls(values=values) - - def decrypt_wrapped_message( - self, - decrypter: Decrypter, - ) -> FairnessAccuracy: - values = [decrypter.decrypt_float(val) for val in self.values] - return FairnessAccuracy(values=values) - - def aggregate( - self, - other: Self, - decrypter: Decrypter, - ) -> Self: - values = [ - decrypter.sum_encrypted([v_a, v_b]) - for v_a, v_b in zip(self.values, other.values) - ] - return self.__class__(values=values) - - -@dataclasses.dataclass -class FairnessCounts(Message): - """Message for client-emitted sample counts across sensitive groups. - - Fields - ------ - counts: - List of group-wise sample counts, ordered based on - an agreed-upon sorted list of sensitive groups. - """ - - counts: List[int] - - typekey = "fairness-counts" - - -@dataclasses.dataclass -class SecaggFairnessCounts(SecaggMessage[FairnessCounts]): - """SecAgg counterpart of the 'FairnessCounts' message class.""" - - counts: List[int] - - typekey = "secagg-fairness-counts" - - @classmethod - def from_cleartext_message( - cls, - cleartext: FairnessCounts, - encrypter: Encrypter, - ) -> Self: - counts = [encrypter.encrypt_uint(val) for val in cleartext.counts] - return cls(counts=counts) - - def decrypt_wrapped_message( - self, - decrypter: Decrypter, - ) -> FairnessCounts: - counts = [decrypter.decrypt_uint(val) for val in self.counts] - return FairnessCounts(counts=counts) - - def aggregate( - self, - other: Self, - decrypter: Decrypter, - ) -> Self: - counts = [ - decrypter.sum_encrypted([v_a, v_b]) - for v_a, v_b in zip(self.counts, other.counts) - ] - return self.__class__(counts=counts) - - -@dataclasses.dataclass -class FairnessGroups(Message): - """Message to exchange a list of unique sensitive group definitions. - - This message may be exchanged both ways, with clients sharing the - list of groups for which they have samples and the server sharing - back a unified, sorted list of all sensitive groups across clients. - - Fields - ------ - groups: - List of sensitive group definitions, defined by tuples of values - corresponding to those of one or more sensitive attributes and - (optionally) a target label. - """ - - groups: List[Tuple[Any, ...]] - - typekey = "fairness-groups" - - @classmethod - def from_kwargs( - cls, - **kwargs: Any, - ) -> Self: - kwargs["groups"] = [tuple(group) for group in kwargs["groups"]] - return super().from_kwargs(**kwargs) diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py new file mode 100644 index 0000000000000000000000000000000000000000..6f893f9fe825e9a8b37d2051725f39debf3b8117 --- /dev/null +++ b/declearn/fairness/api/_server.py @@ -0,0 +1,272 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server-side ABC for fairness-aware federated learning controllers.""" + +import abc +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union + +import numpy as np + +from declearn.aggregator import Aggregator +from declearn.communication.api import NetworkServer +from declearn.communication.utils import verify_client_messages_validity +from declearn.messaging import ( + FairnessCounts, + FairnessGroups, + FairnessSetupQuery, + SerializedMessage, +) + +from declearn.secagg.api import Decrypter +from declearn.secagg.messaging import ( + aggregate_secagg_messages, + SecaggFairnessCounts, +) +from declearn.utils import create_types_registry, register_type + +__all__ = [ + "FairnessControllerServer", +] + + +@create_types_registry(name="FairnessControllerServer") +class FairnessControllerServer(metaclass=abc.ABCMeta): + """Abstract base class for server-side fairness controllers.""" + + algorithm: ClassVar[str] + """Name of the fairness-enforcing algorithm. + + This name should be unique across 'FairnessControllerServer' classes, + and shared with a unique paired 'FairnessControllerClient'. It is used + for type-registration and to enable instructing clients to instantiate + a controller matching that chosen by the server in a federated setting. + """ + + def __init_subclass__( + cls, + register: bool = True, + ) -> None: + """Automatically type-register subclasses.""" + if register: + register_type(cls, cls.algorithm, group="FairnessControllerServer") + + def __init__( + self, + f_type: str, + f_args: Optional[Dict[str, Any]], + ) -> None: + """Instantiate the server-side fairness controller. + + Parameters + ---------- + f_type: + Name of the fairness function to evaluate and optimize. + f_args: + Optional dict of keyword arguments to the fairness function. + """ + self.f_type = f_type + self.f_args = f_args or {} + self.groups = [] # type: List[Tuple[Any, ...]] + + async def setup_fairness( + self, + netwk: NetworkServer, + aggregator: Aggregator, + secagg: Optional[Decrypter], + ) -> Aggregator: + """Orchestrate a routine to initialize fairness-aware learning. + + This routine has the following structure: + + - Send a setup query to clients, resulting in the instantiation + of client-side controllers matching this one. + - Exchange with clients to agree on an ordered list of sensitive + groups defined by the interesection of 1+ sensitive attributes + and (opt.) a classification target label. + - Receive and (secure-)aggregate group-wise sample counts across + clients' training dataset. + - Perform any additional actions specific to the algorithm in use. + - On the server side, optionally alter the `Aggregator` used. + - On the client side, optionally alter the `TrainingManager` used. + + Parameters + ---------- + netwk: + NetworkServer endpoint, to which clients are registered. + aggregator: + Aggregator instance that was set up notwithstanding fairness. + secagg: + Optional SecAgg decryption controller. + + Warns + ----- + RuntimeWarning + If the returned aggregator differs from the input one. + + Returns + ------- + aggregator: + `Aggregator` instance to use in the FL process, that may + or may not have been altered compared with the input one. + """ + # Send a setup query to all clients. + query = self.prepare_fairness_setup_query() + await netwk.broadcast_message(query) + # Receive, aggregate, assign and send back sensitive group definitions. + self.groups = await self._exchange_sensitive_groups_list(netwk) + # Receive, (secure-)aggregate and return group-wise sample counts. + counts = await self._aggregate_sensitive_groups_counts(netwk, secagg) + # Run additional algorithm-specific setup steps. + return await self.finalize_fairness_setup(netwk, counts, aggregator) + + def prepare_fairness_setup_query( + self, + ) -> FairnessSetupQuery: + """Return a request to setup fairness, broadcastable to clients. + + Returns + ------- + message: + `FairnessSetupQuery` instance to be sent to clients in order + to trigger the Fairness setup protocol. + """ + return FairnessSetupQuery(algorithm=self.algorithm) + + @staticmethod + async def _exchange_sensitive_groups_list( + netwk: NetworkServer, + ) -> List[Tuple[Any, ...]]: + """Receive, aggregate, share and return sensitive group definitions.""" + received = await netwk.wait_for_messages() + # Verify and deserialize client-wise sensitive group definitions. + messages = await verify_client_messages_validity( + netwk, received, expected=FairnessGroups + ) + # Gather the sorted union of all existing definitions. + unique = {group for msg in messages.values() for group in msg.groups} + groups = sorted(list(unique)) + # Send it to clients, and expect their reply (encrypted counts). + await netwk.broadcast_message(FairnessGroups(groups=groups)) + return groups + + async def _aggregate_sensitive_groups_counts( + self, + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> List[int]: + """Receive, (secure-)aggregate and return group-wise sample counts.""" + received = await netwk.wait_for_messages() + if secagg is None: + return await self._aggregate_sensitive_groups_counts_cleartext( + netwk=netwk, received=received, n_groups=len(self.groups) + ) + return await self._aggregate_sensitive_groups_counts_encrypted( + netwk=netwk, received=received, decrypter=secagg + ) + + @staticmethod + async def _aggregate_sensitive_groups_counts_cleartext( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + n_groups: int, + ) -> List[int]: + """Deserialize and aggregate cleartext group-wise counts.""" + replies = await verify_client_messages_validity( + netwk, received, expected=FairnessCounts + ) + counts = np.zeros(n_groups, dtype="uint64") + for message in replies.values(): + counts = counts + np.asarray(message.counts, dtype="uint64") + return counts.tolist() + + @staticmethod + async def _aggregate_sensitive_groups_counts_encrypted( + netwk: NetworkServer, + received: Dict[str, SerializedMessage], + decrypter: Decrypter, + ) -> List[int]: + """Deserialize and secure-aggregate encrypted group-wise counts.""" + replies = await verify_client_messages_validity( + netwk, received, expected=SecaggFairnessCounts + ) + aggregated = aggregate_secagg_messages(replies, decrypter) + return aggregated.counts + + @abc.abstractmethod + async def finalize_fairness_setup( + self, + netwk: NetworkServer, + counts: List[int], + aggregator: Aggregator, + ) -> Aggregator: + """Finalize the fairness setup routine and return an Aggregator. + + This method is called as part of `setup_fairness`, and should + be defined by concrete subclasses to implement setup behavior + once the initial query/reply messages have been exchanged. + + The returned `Aggregator` may either be the input `aggregator` + or a new or modified version of it, depending on the needs of + the fairness-aware federated learning process being implemented. + + Warns + ----- + RuntimeWarning + If the returned aggregator differs from the input one. + + Returns + ------- + aggregator: + `Aggregator` instance to use in the FL process, that may + or may not have been altered compared with the input one. + """ + + @abc.abstractmethod + async def finalize_fairness_round( + self, + round_i: int, + values: List[float], + netwk: NetworkServer, + secagg: Optional[Decrypter], + ) -> Dict[str, Union[float, np.ndarray]]: + """Orchestrate a round of actions to enforce fairness. + + This method is designed to be called after an initial query + has been sent and responded to by clients, resulting in the + federated computation of fairness(-related) metrics. + + Parameters + ---------- + round_i: + Index of the current round (reflecting that of an upcoming + training round). + values: + Aggregated metrics resulting from the fairness evaluation + run by clients at this round. + netwk: + NetworkServer endpoint instance, to which clients are registered. + secagg: + Optional SecAgg decryption controller. + + Returns + ------- + metrics: + Computed local fairness(-related) metrics computed as part + of this routine, as a dict mapping scalar or numpy array + values with their name. + """ diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 7770c4fff14cd5386603a1da3c11cc3752ffd6f8..a345ac19db5a8d1bfa551c29a6cd6ed409c29991 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -33,16 +33,13 @@ from declearn.communication.utils import ( verify_server_message_validity, ) from declearn.dataset import Dataset, load_dataset_from_json -from declearn.fairness.api import ( - FairnessControllerClient, - FairnessSetupQuery, -) +from declearn.fairness.api import FairnessControllerClient from declearn.main.utils import Checkpointer from declearn.messaging import Message, SerializedMessage from declearn.training import TrainingManager from declearn.secagg import parse_secagg_config_client from declearn.secagg.api import Encrypter, SecaggConfigClient, SecaggSetupQuery -from declearn.secagg.messaging import SecaggEvaluationReply, SecaggTrainReply +from declearn.secagg import messaging as secagg_messaging from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger @@ -457,27 +454,18 @@ class FederatedClient: and should never be called in another context. """ assert self.trainmanager is not None - # Parse the serialized FairnessSetupQuery. - try: - received = await self.netwk.recv_message() - query = await verify_server_message_validity( - netwk=self.netwk, - received=received, - expected=FairnessSetupQuery, # type: ignore[type-abstract] - ) - except Exception as exc: - error = "Failed to parse fairness setup query." - self.logger.critical(error) - await self.netwk.send_message(messaging.Error(error)) - raise RuntimeError(error) from exc + # Await and deserialize a FairnessSetupQuery. + received = await self.netwk.recv_message() + query = await verify_server_message_validity( + self.netwk, received, expected=messaging.FairnessSetupQuery + ) # Instantiate a FairnessControllerClient and run its setup routine. try: - self.fairness = query.instantiate_controller() - self.trainmanager = await self.fairness.setup_fairness( - netwk=self.netwk, - manager=self.trainmanager, - secagg=self._encrypter, - params=query.get_setup_params(), + self.fairness = FairnessControllerClient.from_setup_query( + query=query, manager=self.trainmanager + ) + await self.fairness.setup_fairness( + netwk=self.netwk, secagg=self._encrypter ) except Exception as exc: error = ( @@ -562,7 +550,7 @@ class FederatedClient: if self._encrypter is not None and isinstance( reply, messaging.TrainReply ): - reply = SecaggTrainReply.from_cleartext_message( + reply = secagg_messaging.SecaggTrainReply.from_cleartext_message( cleartext=reply, encrypter=self._encrypter ) # Send training results (or error message) to the server. @@ -613,7 +601,8 @@ class FederatedClient: reply.metrics.clear() # Optionally SecAgg-encrypt results. if self._encrypter is not None: - reply = SecaggEvaluationReply.from_cleartext_message( + msg_cls = secagg_messaging.SecaggEvaluationReply + reply = msg_cls.from_cleartext_message( cleartext=reply, encrypter=self._encrypter ) # Send evaluation results (or error message) to the server. @@ -651,12 +640,17 @@ class FederatedClient: await self.netwk.send_message(messaging.Error(error)) raise RuntimeError(error) # Otherwise, run the controller's routine. - await self.fairness.fairness_round( - netwk=self.netwk, - query=query, - manager=self.trainmanager, - secagg=self._encrypter, + metrics = await self.fairness.fairness_round( + netwk=self.netwk, query=query, secagg=self._encrypter ) + # Optionally save computed fairness metrics. + if self.ckptr is not None: + self.ckptr.save_metrics( + metrics=metrics, + prefix="fairness_metrics", + append=(query.round_i > 0), + timestamp=f"round_{query.round_i}", + ) async def stop_training( self, diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 7b68744b55064c272eb8075f620300cd88440b16..1286103148fa8686bfed57dc1d95abbcfadf40f2 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -579,12 +579,20 @@ class FederatedServer: ) values = self._aggregate_secagg_replies(secagg_replies).values # Have the fairness controller process results. - await self.fairness.finalize_fairness_round( + metrics = await self.fairness.finalize_fairness_round( round_i=round_i, values=values, netwk=self.netwk, secagg=self._decrypter, ) + # Optionally save computed fairness metrics. + if self.ckptr is not None: + self.ckptr.save_metrics( + metrics=metrics, + prefix="fairness_metrics", + append=(query.round_i > 0), + timestamp=f"round_{query.round_i}", + ) async def training_round( self, diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py index 515a1c249de6e7ea09ea9274a4deaf0e94807106..4717448a32d34efe57b3242627fc5eaae4316c6e 100644 --- a/declearn/messaging/__init__.py +++ b/declearn/messaging/__init__.py @@ -33,8 +33,6 @@ Base messages * [Error][declearn.messaging.Error] * [EvaluationReply][declearn.messaging.EvaluationReply] * [EvaluationRequest][declearn.messaging.EvaluationRequest] -* [FairnessQuery][declearn.messaging.FairnessQuery] -* [FairnessReply][declearn.messaging.FairnessReply] * [GenericMessage][declearn.messaging.GenericMessage] * [InitRequest][declearn.messaging.InitRequest] * [InitReply][declearn.messaging.InitReply] @@ -46,6 +44,14 @@ Base messages * [TrainReply][declearn.messaging.TrainReply] * [TrainRequest][declearn.messaging.TrainRequest] +Fairness algorithms messages +---------------------------- + +* [FairnessCounts][declearn.messaging.FairnessCounts] +* [FairnessGroups][declearn.messaging.FairnessGroups] +* [FairnessQuery][declearn.messaging.FairnessQuery] +* [FairnessReply][declearn.messaging.FairnessReply] +* [FairnessSetupQuery][declearn.messaging.FairnessSetupQuery] """ from ._api import ( @@ -57,8 +63,6 @@ from ._base import ( Error, EvaluationReply, EvaluationRequest, - FairnessQuery, - FairnessReply, GenericMessage, InitRequest, InitReply, @@ -70,3 +74,10 @@ from ._base import ( TrainReply, TrainRequest, ) +from ._fairness import ( + FairnessCounts, + FairnessGroups, + FairnessQuery, + FairnessReply, + FairnessSetupQuery, +) diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py index fe029e23d2a31aec6906853e2247d93c70492a7a..6cec34d26be39428434367ecdff0330d4a77c96b 100644 --- a/declearn/messaging/_base.py +++ b/declearn/messaging/_base.py @@ -36,8 +36,6 @@ __all__ = [ "Error", "EvaluationReply", "EvaluationRequest", - "FairnessQuery", - "FairnessReply", "GenericMessage", "InitRequest", "InitReply", @@ -102,44 +100,6 @@ class EvaluationReply(Message): return kwargs -@dataclasses.dataclass -class FairnessQuery(Message): - """Base Message for server-emitted fairness-computation queries. - - This message conveys hyper-parameters used when evaluating a model's - accuracy and/or loss over group-wise samples (from which fairness is - derived). Model weights may be attached. - - Algorithm-specific information should be conveyed using ad-hoc - messages exchanged as part of fairness-enforcement routines. - """ - - typekey = "fairness-request" - - round_i: int - batch_size: int = 32 - n_batch: Optional[int] = None - thresh: Optional[float] = None - weights: Optional[Vector] = None - - -@dataclasses.dataclass -class FairnessReply(Message): - """Base Message for client-emitted fairness-computation results. - - This message conveys results from the evaluation of a model's accuracy - and/or loss over group-wise samples (from which fairness is derived). - - This information is generically stored as a list of `values`, the - mearning and structure of which is left up to algorithm-specific - controllers. - """ - - typekey = "fairness-reply" - - values: List[float] = dataclasses.field(default_factory=list) - - @dataclasses.dataclass class GenericMessage(Message): """Generic message format, with action/params pair.""" diff --git a/declearn/messaging/_fairness.py b/declearn/messaging/_fairness.py new file mode 100644 index 0000000000000000000000000000000000000000..3af91357acc7b02a2e3d08066ecb669e08ce04cd --- /dev/null +++ b/declearn/messaging/_fairness.py @@ -0,0 +1,136 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Messages for fairness-aware federated learning setup and rounds.""" + +import dataclasses +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Self # future: import from typing (py >=3.11) + +from declearn.messaging import Message +from declearn.model.api import Vector + +__all__ = [ + "FairnessCounts", + "FairnessGroups", + "FairnessQuery", + "FairnessReply", + "FairnessSetupQuery", +] + + +@dataclasses.dataclass +class FairnessCounts(Message): + """Message for client-emitted sample counts across sensitive groups. + + Fields + ------ + counts: + List of group-wise sample counts, ordered based on + an agreed-upon sorted list of sensitive groups. + """ + + counts: List[int] + + typekey = "fairness-counts" + + +@dataclasses.dataclass +class FairnessGroups(Message): + """Message to exchange a list of unique sensitive group definitions. + + This message may be exchanged both ways, with clients sharing the + list of groups for which they have samples and the server sharing + back a unified, sorted list of all sensitive groups across clients. + + Fields + ------ + groups: + List of sensitive group definitions, defined by tuples of values + corresponding to those of one or more sensitive attributes and + (optionally) a target label. + """ + + groups: List[Tuple[Any, ...]] + + typekey = "fairness-groups" + + @classmethod + def from_kwargs( + cls, + **kwargs: Any, + ) -> Self: + kwargs["groups"] = [tuple(group) for group in kwargs["groups"]] + return super().from_kwargs(**kwargs) + + +@dataclasses.dataclass +class FairnessQuery(Message): + """Base Message for server-emitted fairness-computation queries. + + This message conveys hyper-parameters used when evaluating a model's + accuracy and/or loss over group-wise samples (from which fairness is + derived). Model weights may be attached. + + Algorithm-specific information should be conveyed using ad-hoc + messages exchanged as part of fairness-enforcement routines. + """ + + typekey = "fairness-request" + + round_i: int + batch_size: int = 32 + n_batch: Optional[int] = None + thresh: Optional[float] = None + weights: Optional[Vector] = None + + +@dataclasses.dataclass +class FairnessReply(Message): + """Base Message for client-emitted fairness-computation results. + + This message conveys results from the evaluation of a model's accuracy + and/or loss over group-wise samples (from which fairness is derived). + + This information is generically stored as a list of `values`, the + mearning and structure of which is left up to algorithm-specific + controllers. + """ + + typekey = "fairness-reply" + + values: List[float] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass +class FairnessSetupQuery(Message): + """Message to instruct clients to instantiate a fairness controller. + + Fields + ------ + algorithm: + Name of the algorithm, under which the target controller type + is expected to be registered. + params: + Dict of instantiation keyword arguments to the controller. + """ + + typekey = "fairness-setup-query" + + algorithm: str + params: Dict[str, Any] = dataclasses.field(default_factory=dict) diff --git a/declearn/secagg/messaging.py b/declearn/secagg/messaging.py index 9a9067053b01dd53035d7cd50e81a8dabaf6566e..a2949c418851c9f4be6fc108da765a5c33a5f8fb 100644 --- a/declearn/secagg/messaging.py +++ b/declearn/secagg/messaging.py @@ -26,6 +26,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.aggregator import ModelUpdates from declearn.messaging import ( EvaluationReply, + FairnessCounts, FairnessReply, Message, TrainReply, @@ -36,6 +37,7 @@ from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate __all__ = [ "SecaggEvaluationReply", + "SecaggFairnessCounts", "SecaggFairnessReply", "SecaggMessage", "SecaggTrainReply", @@ -264,6 +266,42 @@ class SecaggEvaluationReply(SecaggMessage[EvaluationReply]): ) +@dataclasses.dataclass +class SecaggFairnessCounts(SecaggMessage[FairnessCounts]): + """SecAgg counterpart of the 'FairnessCounts' message class.""" + + counts: List[int] + + typekey = "secagg-fairness-counts" + + @classmethod + def from_cleartext_message( + cls, + cleartext: FairnessCounts, + encrypter: Encrypter, + ) -> Self: + counts = [encrypter.encrypt_uint(val) for val in cleartext.counts] + return cls(counts=counts) + + def decrypt_wrapped_message( + self, + decrypter: Decrypter, + ) -> FairnessCounts: + counts = [decrypter.decrypt_uint(val) for val in self.counts] + return FairnessCounts(counts=counts) + + def aggregate( + self, + other: Self, + decrypter: Decrypter, + ) -> Self: + counts = [ + decrypter.sum_encrypted([v_a, v_b]) + for v_a, v_b in zip(self.counts, other.counts) + ] + return self.__class__(counts=counts) + + @dataclasses.dataclass class SecaggFairnessReply(SecaggMessage[FairnessReply]): """SecAgg-wrapped 'FairnessReply' message."""