diff --git a/declearn/fairness/api/__init__.py b/declearn/fairness/api/__init__.py index e60366cdbd4c99e0c56d564c801d0478c9d27173..832b9e11a0819b54fb22aadec8d072a6d9598bdf 100644 --- a/declearn/fairness/api/__init__.py +++ b/declearn/fairness/api/__init__.py @@ -21,8 +21,6 @@ from ._messages import ( FairnessAccuracy, FairnessCounts, FairnessGroups, - FairnessRoundQuery, - FairnessRoundReply, SecaggFairnessAccuracy, SecaggFairnessCounts, ) diff --git a/declearn/fairness/api/_controllers.py b/declearn/fairness/api/_controllers.py index 06b25786e498f2d2d9a8b661404ea7e47108d62b..52277c0a9303c3b1d05dbebfb494a1d67b1b1b91 100644 --- a/declearn/fairness/api/_controllers.py +++ b/declearn/fairness/api/_controllers.py @@ -32,13 +32,15 @@ from declearn.communication.utils import ( from declearn.fairness.api._messages import ( FairnessCounts, FairnessGroups, - FairnessRoundQuery, SecaggFairnessCounts, ) -from declearn.fairness.core import FairnessDataset -from declearn.messaging import Error, Message, SerializedMessage +from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset +from declearn.messaging import Error, FairnessQuery, FairnessReply, Message from declearn.secagg.api import Decrypter, Encrypter -from declearn.secagg.messaging import aggregate_secagg_messages +from declearn.secagg.messaging import ( + aggregate_secagg_messages, + SecaggFairnessReply, +) from declearn.training import TrainingManager __all__ = [ @@ -170,12 +172,11 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): or may not have been altered compared with the input one. """ - @abc.abstractmethod async def fairness_round( self, netwk: NetworkClient, + query: FairnessQuery, manager: TrainingManager, - received: SerializedMessage[FairnessRoundQuery], secagg: Optional[Encrypter], ) -> None: """Participate in a round of actions to enforce fairness. 
@@ -184,11 +185,93 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ---------- netwk: NetworkClient endpoint instance, connected to a server. + query: + `FairnessQuery` message to participate in a fairness round. + manager: + TrainingManager instance holding the local model, optimizer, etc. + This method may (and usually does) have side effects on this. + secagg: + Optional SecAgg encryption controller. + """ + values = self.compute_fairness_measures(query, manager) + reply = FairnessReply(values=values) + if secagg is None: + await netwk.send_message(reply) + else: + await netwk.send_message( + SecaggFairnessReply.from_cleartext_message(reply, secagg) + ) + await self.finalize_fairness_round(netwk, values, manager, secagg) + + def compute_fairness_measures( + self, + query: FairnessQuery, + manager: TrainingManager, + ) -> List[float]: + """Compute fairness measures based on a received query. + + By default, compute and return group-wise accuracy metrics, + weighted by group-wise sample counts. This may be modified + by algorithm-specific subclasses depending on algorithms' + needs. + + Parameters + ---------- + query: + `FairnessQuery` message with computational effort constraints, + and optionally model weights to assign before evaluation. + manager: + TrainingManager instance holding the model to evaluate and the + training dataset on which to do so. + + Returns + ------- + values: + Computed values, as a deterministic-length ordered list + of float values. + """ + assert isinstance(manager.train_data, FairnessDataset) + if query.weights is not None: + manager.model.set_weights(query.weights, trainable=True) + # Compute group-wise accuracy metrics. + computer = FairnessAccuracyComputer(manager.train_data) + accuracy = computer.compute_groupwise_accuracy( + model=manager.model, + batch_size=query.batch_size, + n_batch=query.n_batch, + thresh=query.thresh, + ) + # Scale computed accuracy metrics by sample counts. 
+ accuracy = { + key: val * computer.counts[key] for key, val in accuracy.items() + } + # Gather ordered values (filling-in groups without samples). + return [accuracy.get(group, 0.0) for group in self.groups] + + @abc.abstractmethod + async def finalize_fairness_round( + self, + netwk: NetworkClient, + values: List[float], + manager: TrainingManager, + secagg: Optional[Encrypter], + ) -> None: + """Take actions to enforce fairness. + + This method is designed to be called after an initial query + has been received and responded to, resulting in computing + and sharing fairness(-related) metrics. + + Parameters + ---------- + netwk: + NetworkClient endpoint instance, connected to a server. + values: + List of locally-computed evaluation metrics, already shared + with the server for their (secure-)aggregation. manager: TrainingManager instance holding the local model, optimizer, etc. This method may (and usually does) have side effects on this. - received: - Serialized query message to participated in a fairness round. secagg: Optional SecAgg encryption controller. """ @@ -367,19 +450,27 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): """ @abc.abstractmethod - async def fairness_round( + async def finalize_fairness_round( self, round_i: int, + values: List[float], netwk: NetworkServer, secagg: Optional[Decrypter], ) -> None: """Orchestrate a round of actions to enforce fairness. + This method is designed to be called after an initial query + has been sent and responded to by clients, resulting in the + federated computation of fairness(-related) metrics. + Parameters ---------- round_i: Index of the current round (reflecting that of an upcoming training round). + values: + Aggregated metrics resulting from the fairness evaluation + run by clients at this round. netwk: NetworkServer endpoint instance, to which clients are registered. 
secagg: diff --git a/declearn/fairness/api/_messages.py b/declearn/fairness/api/_messages.py index eae28a455c540991f2b7a23441a9fd25f85bf6cc..7beb0785ddcfb076ad9bc37009168a5da5dd66ad 100644 --- a/declearn/fairness/api/_messages.py +++ b/declearn/fairness/api/_messages.py @@ -18,7 +18,7 @@ """API messages for fairness-aware federated learning setup and rounds.""" import dataclasses -from typing import Any, List, Optional, Tuple +from typing import Any, List, Tuple from typing_extensions import Self # future: import from typing (py >=3.11) @@ -30,8 +30,6 @@ __all__ = [ "FairnessAccuracy", "FairnessCounts", "FairnessGroups", - "FairnessRoundQuery", - "FairnessRoundReply", "SecaggFairnessAccuracy", "SecaggFairnessCounts", ] @@ -168,46 +166,3 @@ class FairnessGroups(Message): ) -> Self: kwargs["groups"] = [tuple(group) for group in kwargs["groups"]] return super().from_kwargs(**kwargs) - - -@dataclasses.dataclass -class FairnessRoundQuery(Message): - """Base Message for server-emitted fairness-computation queries. - - The base `FairnessRoundQuery` defines information that is used - when evaluating a model's accuracy and/or loss over group-wise - training samples. - - Subclasses may be defined to add algorithm-specific information. - - Fields - ------ - batch_size: - Number of samples per batch when computing metrics. - n_batch: - Optional maximum number of batches to draw per group. - If None, use the entire wrapped dataset. - thresh: - Optional binarization threshold for binary classification - models' output scores. If None, use 0.5 by default, or 0.0 - for `SklearnSGDModel` instances. - Unused for multinomial classifiers (argmax over scores). - """ - - batch_size: int = 32 - n_batch: Optional[int] = None - thresh: Optional[float] = None - - typekey = "fairness-round-query" - - -@dataclasses.dataclass -class FairnessRoundReply(Message): - """Base Message for client-emitted fairness-round end signal. 
- - By default this message is empty, merely noticing that things - went well. Subclasses may be used to convey algorithm-specific - results or information. - """ - - typekey = "fairness-round-reply" diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index 1af21bad579f7d3bd9f2097aa29ad5efbfe5f9de..1fa63ac4417a90152f0c918ced1b12437ccade9d 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -17,7 +17,7 @@ """Client-side Fed-FairGrad controller.""" -from typing import List, Optional +from typing import Any, Dict, List, Optional from declearn.communication.api import NetworkClient @@ -27,7 +27,6 @@ from declearn.fairness.api import ( FairnessRoundQuery, FairnessRoundReply, FairnessControllerClient, - FairnessSetupQuery, SecaggFairnessAccuracy, ) from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset @@ -60,8 +59,9 @@ class FairgradControllerClient(FairnessControllerClient): async def finalize_fairness_setup( self, netwk: NetworkClient, - query: FairnessSetupQuery, manager: TrainingManager, + secagg: Optional[Encrypter], + params: Dict[str, Any], ) -> TrainingManager: assert isinstance(manager.train_data, FairnessDataset) # Set up a controller to compute group-wise model accuracy. 
diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 75ac07ec38183240993e75f41c9794fb0491fb78..7770c4fff14cd5386603a1da3c11cc3752ffd6f8 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -36,7 +36,6 @@ from declearn.dataset import Dataset, load_dataset_from_json from declearn.fairness.api import ( FairnessControllerClient, FairnessSetupQuery, - FairnessRoundQuery, ) from declearn.main.utils import Checkpointer from declearn.messaging import Message, SerializedMessage @@ -254,8 +253,8 @@ class FederatedClient: await self.training_round(message.deserialize()) elif issubclass(message.message_cls, messaging.EvaluationRequest): await self.evaluation_round(message.deserialize()) - elif issubclass(message.message_cls, FairnessRoundQuery): - await self.fairness_round(message) # note: keep serialized + elif issubclass(message.message_cls, messaging.FairnessQuery): + await self.fairness_round(message.deserialize()) elif issubclass(message.message_cls, SecaggSetupQuery): await self.setup_secagg(message) # note: keep serialized elif issubclass(message.message_cls, messaging.StopTraining): @@ -622,7 +621,7 @@ class FederatedClient: async def fairness_round( self, - received: SerializedMessage[FairnessRoundQuery], + query: messaging.FairnessQuery, ) -> None: """Handle a server request to run a fairness-related round. @@ -633,8 +632,8 @@ class FederatedClient: Parameters ---------- - received: - Serialized `FairnessRoundQuery` message from the server. + query: + `FairnessQuery` message from the server. Raises ------ @@ -645,9 +644,8 @@ class FederatedClient: # If no fairness controller was set up, raise a RuntimeError. if self.fairness is None: error = ( - "Received a query to participate in a fairness round " - f"('{received.message_cls.__name__}'), but no fairness " - "controller was set up." + "Received a query to participate in a fairness round, " + "but no fairness controller was set up." 
) self.logger.critical(error) await self.netwk.send_message(messaging.Error(error)) @@ -655,8 +653,8 @@ class FederatedClient: # Otherwise, run the controller's routine. await self.fairness.fairness_round( netwk=self.netwk, + query=query, manager=self.trainmanager, - received=received, secagg=self._encrypter, ) diff --git a/declearn/main/_server.py b/declearn/main/_server.py index be099444bec9fb328c922bca8372edd71ff34866..7b68744b55064c272eb8075f620300cd88440b16 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -31,9 +31,9 @@ import numpy as np from declearn import messaging from declearn.communication import NetworkServerConfig from declearn.communication.api import NetworkServer -from declearn.fairness.api import FairnessControllerServer from declearn.main.config import ( EvaluateConfig, + FairnessConfig, FLOptimConfig, FLRunConfig, TrainingConfig, @@ -48,14 +48,9 @@ from declearn.metrics import MetricInputType, MetricSet from declearn.metrics._mean import MeanState from declearn.model.api import Model, Vector from declearn.optimizer.modules import AuxVar +from declearn.secagg import messaging as secagg_messaging from declearn.secagg import parse_secagg_config_server from declearn.secagg.api import Decrypter, SecaggConfigServer -from declearn.secagg.messaging import ( - SecaggEvaluationReply, - SecaggMessage, - SecaggTrainReply, - aggregate_secagg_messages, -) from declearn.utils import deserialize_object, get_logger @@ -79,7 +74,6 @@ class FederatedServer: optim: Union[FLOptimConfig, str, Dict[str, Any]], metrics: Union[MetricSet, List[MetricInputType], None] = None, secagg: Union[SecaggConfigServer, Dict[str, Any], None] = None, - fairness: Union[FairnessControllerServer, None] = None, checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None, logger: Union[logging.Logger, str, None] = None, ) -> None: @@ -108,8 +102,6 @@ class FederatedServer: secagg: SecaggConfigServer or dict or None, default=None Optional SecAgg config and 
setup controller or dict of kwargs to set one up. - fairness: FairnessControllerServer of None, default=None - Optional Fairness-aware Federated Learning controller. checkpoint: Checkpointer or dict or str or None, default=None Optional Checkpointer instance or instantiation dict to be used so as to save round-wise model, optimizer and metrics. @@ -133,6 +125,7 @@ class FederatedServer: self.aggrg = optim.aggregator self.optim = optim.server_opt self.c_opt = optim.client_opt + self.fairness = optim.fairness # note: optional # Assign the wrapped MetricSet. self.metrics = MetricSet.from_specs(metrics) # Assign an optional checkpointer. @@ -143,8 +136,6 @@ class FederatedServer: self.secagg = self._parse_secagg(secagg) self._decrypter = None # type: Optional[Decrypter] self._secagg_peers = set() # type: Set[str] - # Assign the optional FairnessControllerServer. - self.fairness = fairness # TODO: add proper parser and alternatives # Set up private attributes to record the loss values and best weights. self._loss = {} # type: Dict[int, float] self._best = None # type: Optional[Vector] @@ -277,7 +268,8 @@ class FederatedServer: specify the federated learning process, including clients registration, training and validation rounds' setup, plus optional elements: local differential-privacy parameters, - and/or an early-stopping criterion. + fairness evaluation parameters, and/or an early-stopping + criterion. """ # Instantiate the early-stopping criterion, if any. 
early_stop = None # type: Optional[EarlyStopping] @@ -293,7 +285,8 @@ round_i = 0 while True: round_i += 1 - # TODO: await self.fairness_round(round_i, config.fairness) + if self.fairness is not None: + await self.fairness_round(round_i, config.fairness) await self.training_round(round_i, config.training) await self.evaluation_round(round_i, config.evaluate) if not self._keep_training(round_i, config.rounds, early_stop): @@ -529,11 +522,69 @@ class FederatedServer: def _aggregate_secagg_replies( self, - replies: Mapping[str, SecaggMessage[MessageT]], + replies: Mapping[str, secagg_messaging.SecaggMessage[MessageT]], ) -> MessageT: """Secure-Aggregate (and decrypt) client-issued encrypted messages.""" assert self._decrypter is not None - return aggregate_secagg_messages(replies, decrypter=self._decrypter) + return secagg_messaging.aggregate_secagg_messages( + replies, decrypter=self._decrypter + ) + + async def fairness_round( + self, + round_i: int, + fairness_cfg: FairnessConfig, + ) -> None: + """Orchestrate a fairness round. + + Parameters + ---------- + round_i: + Index of the training round. + fairness_cfg: + FairnessConfig dataclass instance wrapping data-batching + and computational effort constraints hyper-parameters for + fairness evaluation. + """ + assert self.fairness is not None + # Run SecAgg setup when needed. + self.logger.info("Initiating fairness-enforcing round %s", round_i) + clients = self.netwk.client_names # FUTURE: enable sampling(?) + if self.secagg is not None and clients.difference(self._secagg_peers): + await self.setup_secagg(clients) + # Send a query to clients, including model weights when required. + query = messaging.FairnessQuery( + round_i=round_i, + batch_size=fairness_cfg.batch_size, + n_batch=fairness_cfg.n_batch, + thresh=fairness_cfg.thresh, + weights=None, + ) + await self._send_request_with_optional_weights(query, clients) + # Await and (secure-)aggregate results. 
+ self.logger.info("Awaiting clients' fairness measures.") + if self._decrypter is None: + replies = await self._collect_results( + clients, messaging.FairnessReply, "fairness round" + ) + if len(set(len(r.values) for r in replies.values())) != 1: + error = "Clients sent fairness values of different lengths." + self.logger.error(error) + await self.netwk.broadcast_message(messaging.Error(error)) + raise RuntimeError(error) + values = [sum(c_values) for c_values in zip(*replies.values())] + else: + secagg_replies = await self._collect_results( + clients, secagg_messaging.SecaggFairnessReply, "fairness round" + ) + values = self._aggregate_secagg_replies(secagg_replies).values + # Have the fairness controller process results. + await self.fairness.finalize_fairness_round( + round_i=round_i, + values=values, + netwk=self.netwk, + secagg=self._decrypter, + ) async def training_round( self, @@ -564,7 +615,7 @@ class FederatedServer: ) else: secagg_results = await self._collect_results( - clients, SecaggTrainReply, "training" + clients, secagg_messaging.SecaggTrainReply, "training" ) results = { "aggregated": self._aggregate_secagg_replies(secagg_results) @@ -609,7 +660,11 @@ class FederatedServer: async def _send_request_with_optional_weights( self, - msg_light: Union[messaging.TrainRequest, messaging.EvaluationRequest], + msg_light: Union[ + messaging.TrainRequest, + messaging.EvaluationRequest, + messaging.FairnessQuery, + ], clients: Set[str], ) -> None: """Send a request to clients, sparingly adding model weights to it. 
@@ -693,7 +748,7 @@ class FederatedServer: ) else: secagg_results = await self._collect_results( - clients, SecaggEvaluationReply, "evaluation" + clients, secagg_messaging.SecaggEvaluationReply, "evaluation" ) results = { "aggregated": self._aggregate_secagg_replies(secagg_results) diff --git a/declearn/main/config/__init__.py b/declearn/main/config/__init__.py index c26bb7eef8287fac3cdfbad24678055809e9456c..42e7c0e5cc1c5d869d9efbe35068eb09e1452a00 100644 --- a/declearn/main/config/__init__.py +++ b/declearn/main/config/__init__.py @@ -33,6 +33,8 @@ The following dataclasses are articulated by `FLRunConfig`: * [EvaluateConfig][declearn.main.config.EvaluateConfig]: Hyper-parameters for an evaluation round. +* [FairnessConfig][declearn.main.config.FairnessConfig]: + Dataclass wrapping parameters for fairness evaluation rounds. * [RegisterConfig][declearn.main.config.RegisterConfig]: Hyper-parameters for clients registration. * [TrainingConfig][declearn.main.config.TrainingConfig]: @@ -41,6 +43,7 @@ The following dataclasses are articulated by `FLRunConfig`: from ._dataclasses import ( EvaluateConfig, + FairnessConfig, PrivacyConfig, RegisterConfig, TrainingConfig, diff --git a/declearn/main/config/_dataclasses.py b/declearn/main/config/_dataclasses.py index 0c4f56148de97402d3aed1355b164d3486870d55..867343e689aa097273bd238bee54dbd32739e82f 100644 --- a/declearn/main/config/_dataclasses.py +++ b/declearn/main/config/_dataclasses.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Optional, Tuple __all__ = [ "EvaluateConfig", + "FairnessConfig", "PrivacyConfig", "RegisterConfig", "TrainingConfig", @@ -223,3 +224,30 @@ class PrivacyConfig: accountants = ("rdp", "gdp", "prv") if self.accountant not in accountants: raise TypeError(f"'accountant' should be one of {accountants}") + + +@dataclasses.dataclass +class FairnessConfig: + """Dataclass wrapping parameters for fairness evaluation rounds. 
+ + The parameters wrapped by this class are those of + `declearn.fairness.core.FairnessAccuracyComputer` + metrics-computation methods. + + Attributes + ---------- + batch_size: int + Number of samples per processed data batch. + n_batch: int or None, default=None + Optional maximum number of batches to draw. + If None, use the entire training dataset. + thresh: float or None, default=None + Optional binarization threshold for binary classification + models' output scores. If None, use 0.5 by default, or 0.0 + for `SklearnSGDModel` instances. + Unused for multinomial classifiers (argmax over scores). + """ + + batch_size: int = 32 + n_batch: Optional[int] = None + thresh: Optional[float] = None diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py index ec68dbbd5353d59d13955547656e48bf3f9d42ac..25b9359731491a77900a8b88fc3483818e4d4165 100644 --- a/declearn/main/config/_run_config.py +++ b/declearn/main/config/_run_config.py @@ -25,6 +25,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.main.utils import EarlyStopConfig from declearn.main.config._dataclasses import ( EvaluateConfig, + FairnessConfig, PrivacyConfig, RegisterConfig, TrainingConfig, @@ -66,6 +67,10 @@ class FLRunConfig(TomlConfig): and data-batching instructions. - evaluate: EvaluateConfig Parameters for validation rounds, similar to training ones. + - fairness: FairnessConfig or None + Parameters for fairness evaluation rounds. + Only used when an algorithm to enforce fairness is set up, + as part of the process's federated optimization configuration. - privacy: PrivacyConfig or None Optional parameters to set up local differential privacy, by having clients use the DP-SGD algorithm for training. @@ -90,12 +95,15 @@ class FLRunConfig(TomlConfig): batch size will be used for evaluation as well. 
- If `privacy` is provided and the 'poisson' parameter is unspecified for `training`, it will be set to True by default rather than False. + - If `fairness` is not provided or lacks a 'batch_size' parameter, + that of evaluation (or, by extension, training) will be used. """ rounds: int register: RegisterConfig training: TrainingConfig evaluate: EvaluateConfig + fairness: FairnessConfig privacy: Optional[PrivacyConfig] = None early_stop: Optional[EarlyStopConfig] = None # type: ignore # is a type @@ -128,7 +136,7 @@ class FLRunConfig(TomlConfig): # If evaluation batch size is not set, use the same as training. # Note: if inputs have invalid formats, let the parent method fail. evaluate = kwargs.setdefault("evaluate", {}) - if isinstance(evaluate, dict): + if isinstance(evaluate, dict) and ("batch_size" not in evaluate): training = kwargs.get("training") if isinstance(training, dict): evaluate.setdefault("batch_size", training.get("batch_size")) @@ -141,5 +149,14 @@ class FLRunConfig(TomlConfig): training = kwargs.get("training") if isinstance(training, dict): training.setdefault("poisson", True) + # If fairness batch size is not set, use the same as evaluation. + # Note: if inputs have invalid formats, let the parent method fail. + fairness = kwargs.setdefault("fairness", {}) + if isinstance(fairness, dict) and ("batch_size" not in fairness): + evaluate = kwargs.get("evaluate") + if isinstance(evaluate, dict): + fairness.setdefault("batch_size", evaluate.get("batch_size")) + elif isinstance(evaluate, EvaluateConfig): + fairness.setdefault("batch_size", evaluate.batch_size) # Delegate the rest of the work to the parent method. 
return super().from_params(**kwargs) diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py index 333072a74f9d122aa4d76df2fae8347c4ebad5cd..707b93a22532e50712f1e0b41a420354be4e9927 100644 --- a/declearn/main/config/_strategy.py +++ b/declearn/main/config/_strategy.py @@ -19,10 +19,11 @@ import dataclasses import functools -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union from declearn.aggregator import Aggregator, AveragingAggregator +from declearn.fairness.api import FairnessControllerServer from declearn.optimizer import Optimizer from declearn.utils import TomlConfig, access_registered, deserialize_object @@ -59,6 +60,9 @@ class FLOptimConfig(TomlConfig): - aggregator: Aggregator, default=AverageAggregator() Client weights aggregator to be used by the server so as to conduct the round-wise aggregation of client udpates. + - fairness: Fairness or None, default=None + Optional `FairnessControllerServer` instance specifying + an algorithm to enforce fairness of the trained model. 
Notes ----- @@ -98,6 +102,7 @@ class FLOptimConfig(TomlConfig): aggregator: Aggregator = dataclasses.field( default_factory=AveragingAggregator ) + fairness: Optional[FairnessControllerServer] = None @classmethod def parse_client_opt( diff --git a/declearn/messaging/__init__.py b/declearn/messaging/__init__.py index 17f5e5e42fd6be0fbe6566dbcd68817937447f67..515a1c249de6e7ea09ea9274a4deaf0e94807106 100644 --- a/declearn/messaging/__init__.py +++ b/declearn/messaging/__init__.py @@ -33,6 +33,8 @@ Base messages * [Error][declearn.messaging.Error] * [EvaluationReply][declearn.messaging.EvaluationReply] * [EvaluationRequest][declearn.messaging.EvaluationRequest] +* [FairnessQuery][declearn.messaging.FairnessQuery] +* [FairnessReply][declearn.messaging.FairnessReply] * [GenericMessage][declearn.messaging.GenericMessage] * [InitRequest][declearn.messaging.InitRequest] * [InitReply][declearn.messaging.InitReply] @@ -55,6 +57,8 @@ from ._base import ( Error, EvaluationReply, EvaluationRequest, + FairnessQuery, + FairnessReply, GenericMessage, InitRequest, InitReply, diff --git a/declearn/messaging/_base.py b/declearn/messaging/_base.py index 6cec34d26be39428434367ecdff0330d4a77c96b..fe029e23d2a31aec6906853e2247d93c70492a7a 100644 --- a/declearn/messaging/_base.py +++ b/declearn/messaging/_base.py @@ -36,6 +36,8 @@ __all__ = [ "Error", "EvaluationReply", "EvaluationRequest", + "FairnessQuery", + "FairnessReply", "GenericMessage", "InitRequest", "InitReply", @@ -100,6 +102,44 @@ class EvaluationReply(Message): return kwargs +@dataclasses.dataclass +class FairnessQuery(Message): + """Base Message for server-emitted fairness-computation queries. + + This message conveys hyper-parameters used when evaluating a model's + accuracy and/or loss over group-wise samples (from which fairness is + derived). Model weights may be attached. + + Algorithm-specific information should be conveyed using ad-hoc + messages exchanged as part of fairness-enforcement routines. 
+ """ + + typekey = "fairness-request" + + round_i: int + batch_size: int = 32 + n_batch: Optional[int] = None + thresh: Optional[float] = None + weights: Optional[Vector] = None + + +@dataclasses.dataclass +class FairnessReply(Message): + """Base Message for client-emitted fairness-computation results. + + This message conveys results from the evaluation of a model's accuracy + and/or loss over group-wise samples (from which fairness is derived). + + This information is generically stored as a list of `values`, the + mearning and structure of which is left up to algorithm-specific + controllers. + """ + + typekey = "fairness-reply" + + values: List[float] = dataclasses.field(default_factory=list) + + @dataclasses.dataclass class GenericMessage(Message): """Generic message format, with action/params pair.""" diff --git a/declearn/secagg/messaging.py b/declearn/secagg/messaging.py index 43a67167658616993a8c54a3077752e568c42c36..9a9067053b01dd53035d7cd50e81a8dabaf6566e 100644 --- a/declearn/secagg/messaging.py +++ b/declearn/secagg/messaging.py @@ -19,18 +19,24 @@ import abc import dataclasses -from typing import Dict, Generic, Mapping, TypeVar +from typing import Dict, Generic, List, Mapping, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.aggregator import ModelUpdates -from declearn.messaging import EvaluationReply, Message, TrainReply +from declearn.messaging import ( + EvaluationReply, + FairnessReply, + Message, + TrainReply, +) from declearn.metrics import MetricState from declearn.optimizer.modules import AuxVar from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate __all__ = [ "SecaggEvaluationReply", + "SecaggFairnessReply", "SecaggMessage", "SecaggTrainReply", "aggregate_secagg_messages", @@ -256,3 +262,44 @@ class SecaggEvaluationReply(SecaggMessage[EvaluationReply]): return self.__class__( loss=loss, n_steps=n_steps, t_spent=t_spent, metrics=metrics ) + + +@dataclasses.dataclass +class 
SecaggFairnessReply(SecaggMessage[FairnessReply]): + """SecAgg-wrapped 'FairnessReply' message.""" + + typekey = "secagg_fairness_reply" + + values: List[int] + + @classmethod + def from_cleartext_message( + cls, + cleartext: FairnessReply, + encrypter: Encrypter, + ) -> Self: + values = [encrypter.encrypt_float(value) for value in cleartext.values] + return cls(values=values) + + def decrypt_wrapped_message( + self, + decrypter: Decrypter, + ) -> FairnessReply: + values = [decrypter.decrypt_float(value) for value in self.values] + return FairnessReply(values=values) + + def aggregate( + self, + other: Self, + decrypter: Decrypter, + ) -> Self: + if len(self.values) != len(other.values): + raise ValueError( + "Cannot aggregate SecAgg-protected fairness values with " + "distinct shapes." + ) + values = [ + decrypter.sum_encrypted([v_a, v_b]) + for v_a, v_b in zip(self.values, other.values) + ] + return self.__class__(values=values)