Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit e599b23a authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Revise fairness controllers API and integration.

- Revise the way client-side controllers are instantiated from
  server-emitted instructions, using the usual type-registration
  tools and removing the need to subclass setup query messages.
- Move fairness-related base messages to the 'messaging' module
  (and 'secagg.messaging' one).
- Have the client-side controller access and wrap a training
  manager at instantiation. Revise method signatures and docs.
- Add fairness metrics checkpointing.
parent 604620ee
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
......@@ -17,15 +17,5 @@
"""Draft API for Fairness-aware Federated Learning algorithms."""
from ._messages import (
FairnessAccuracy,
FairnessCounts,
FairnessGroups,
SecaggFairnessAccuracy,
SecaggFairnessCounts,
)
from ._controllers import (
FairnessControllerClient,
FairnessControllerServer,
FairnessSetupQuery,
)
from ._client import FairnessControllerClient
from ._server import FairnessControllerServer
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client-side ABC for fairness-aware federated learning controllers."""
import abc
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
import numpy as np
from declearn.communication.api import NetworkClient
from declearn.communication.utils import verify_server_message_validity
from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
from declearn.messaging import (
Error,
FairnessCounts,
FairnessGroups,
FairnessQuery,
FairnessReply,
FairnessSetupQuery,
)
from declearn.secagg.api import Encrypter
from declearn.secagg.messaging import (
SecaggFairnessCounts,
SecaggFairnessReply,
)
from declearn.training import TrainingManager
from declearn.utils import (
access_registered,
create_types_registry,
register_type,
)
__all__ = [
"FairnessControllerClient",
]
@create_types_registry(name="FairnessControllerClient")
class FairnessControllerClient(metaclass=abc.ABCMeta):
"""Abstract base class for client-side fairness controllers."""
algorithm: ClassVar[str]
"""Name of the fairness-enforcing algorithm.
This name should be unique across 'FairnessControllerClient' classes,
and shared with a unique paired 'FairnessControllerServer'. It is used
for type-registration and to enable instantiating a client controller
based on server-emitted instructions in a federated setting.
"""
def __init_subclass__(
cls,
register: bool = True,
) -> None:
"""Automatically type-register subclasses."""
if register:
register_type(cls, cls.algorithm, group="FairnessControllerClient")
def __init__(
self,
manager: TrainingManager,
) -> None:
"""Instantiate the client-side fairness controller.
Parameters
----------
manager:
`TrainingManager` instance wrapping the model being trained
and its training dataset (that must be a `FairnessDataset`).
"""
if not isinstance(manager.train_data, FairnessDataset):
raise TypeError(
"Cannot set up fairness without a 'FairnessDataset' "
"as training dataset."
)
self.manager = manager
self.computer = FairnessAccuracyComputer(manager.train_data)
self.groups = [] # type: List[Tuple[Any, ...]]
@staticmethod
def from_setup_query(
query: FairnessSetupQuery,
manager: TrainingManager,
) -> "FairnessControllerClient":
"""Instantiate a controller from a server-emitted query.
Parameters
----------
query:
`FairnessSetupQuery` received from the server.
manager:
`TrainingManager` wrapping the model to train.
Returns
-------
controller:
`FairnessControllerClient` instance, the type and parameters
of which depend on the input `query`, that wraps `manager`.
"""
try:
cls = access_registered(
name=query.algorithm, group="FairnessControllerClient"
)
assert issubclass(cls, FairnessControllerClient)
except Exception as exc:
raise ValueError(
"Failed to retrieve a 'FairnessControllerClient' class "
"matching the input 'FairnessSetupQuery' message."
) from exc
return cls(manager=manager, **query.params)
async def setup_fairness(
self,
netwk: NetworkClient,
secagg: Optional[Encrypter],
) -> None:
"""Participate in a routine to initialize fairness-aware learning.
This routine has the following structure:
- Exchange with the server to agree on an ordered list of sensitive
groups defined by the interesection of 1+ sensitive attributes
and (opt.) a classification target label.
- Send (encrypted) group-wise training sample counts, that the server
is to (secure-)aggregate.
- Perform any additional actions specific to the algorithm in use.
- On the client side, optionally alter the `TrainingManager` used.
- On the server side, optionally alter the `Aggregator` used.
Parameters
----------
netwk:
NetworkClient endpoint, registered to a server.
secagg:
Optional SecAgg encryption controller.
"""
# Share sensitive groups definitions and received an ordered list.
self.groups = await self._exchange_sensitive_groups_list(netwk)
# Send group-wise sample counts for the server to (secure-)aggregate.
await self._send_sensitive_groups_counts(netwk, secagg)
# Run additional algorithm-specific setup steps.
await self.finalize_fairness_setup(netwk, secagg)
async def _exchange_sensitive_groups_list(
self,
netwk: NetworkClient,
) -> List[Tuple[Any, ...]]:
"""Exhange sensitive groups definitions and return a unified list."""
# Gather local sensitive groups and their sample counts.
counts = self.computer.counts
groups = list(counts)
# Share them and receive a unified, ordered list of groups.
await netwk.send_message(FairnessGroups(groups=groups))
received = await netwk.recv_message()
message = await verify_server_message_validity(
netwk, received, expected=FairnessGroups
)
return message.groups
async def _send_sensitive_groups_counts(
self,
netwk: NetworkClient,
secagg: Optional[Encrypter],
) -> None:
"""Send (opt. encrypted) group-wise sample counts to the server."""
counts = self.computer.counts
reply = FairnessCounts([counts.get(group, 0) for group in self.groups])
if secagg is None:
await netwk.send_message(reply)
else:
await netwk.send_message(
SecaggFairnessCounts.from_cleartext_message(reply, secagg)
)
@abc.abstractmethod
async def finalize_fairness_setup(
self,
netwk: NetworkClient,
secagg: Optional[Encrypter],
) -> None:
"""Finalize the fairness setup routine.
This method is called as part of `setup_fairness`, and should
be defined by concrete subclasses to implement setup behavior
once the initial echange of sensitive group definitions and
sample counts has been performed.
Parameters
----------
netwk:
NetworkClient endpoint, registered to a server.
secagg:
Optional SecAgg encryption controller.
"""
async def fairness_round(
self,
netwk: NetworkClient,
query: FairnessQuery,
secagg: Optional[Encrypter],
) -> Dict[str, Union[float, np.ndarray]]:
"""Participate in a round of actions to enforce fairness.
Parameters
----------
netwk:
NetworkClient endpoint instance, connected to a server.
query:
`FairnessQuery` message to participate in a fairness round.
secagg:
Optional SecAgg encryption controller.
Returns
-------
metrics:
Computed local fairness(-related) metrics computed as part
of this routine, as a dict mapping scalar or numpy array
values with their name.
"""
try:
values = await self._compute_and_share_fairness_measures(
netwk, query, secagg
)
except Exception as exc:
error = f"Error encountered in fairness round: {repr(exc)}"
self.manager.logger.error(error)
await netwk.send_message(Error(error))
raise RuntimeError(error) from exc
# Run additional algorithm-specific steps.
return await self.finalize_fairness_round(netwk, values, secagg)
async def _compute_and_share_fairness_measures(
self,
netwk: NetworkClient,
query: FairnessQuery,
secagg: Optional[Encrypter],
) -> List[float]:
"""Compute, share (encrypted) and return fairness measures."""
# Optionally update the wrapped model's weights.
if query.weights is not None:
self.manager.model.set_weights(query.weights, trainable=True)
# Compute, opt. encrypt and share fairness-related metrics.
values = self.compute_fairness_measures(
query.batch_size, query.n_batch, query.thresh
)
reply = FairnessReply(values=values)
if secagg is None:
await netwk.send_message(reply)
else:
await netwk.send_message(
SecaggFairnessReply.from_cleartext_message(reply, secagg)
)
# Return computed values.
return values
@abc.abstractmethod
def compute_fairness_measures(
self,
batch_size: int,
n_batch: Optional[int] = None,
thresh: Optional[float] = None,
) -> List[float]:
"""Compute fairness measures based on a received query.
By default, compute and return group-wise accuracy metrics,
weighted by group-wise sample counts. This may be modified
by algorithm-specific subclasses depending on algorithms'
needs.
Parameters
----------
batch_size:
Number of samples per batch when computing predictions.
n_batch:
Optional maximum number of batches to draw per category.
If None, use the entire wrapped dataset.
thresh:
Optional binarization threshold for binary classification
models' output scores. If None, use 0.5 by default, or 0.0
for `SklearnSGDModel` instances.
Unused for multinomial classifiers (argmax over scores).
Returns
-------
values:
Computed values, as a deterministic-length ordered list
of float values.
"""
@abc.abstractmethod
async def finalize_fairness_round(
self,
netwk: NetworkClient,
values: List[float],
secagg: Optional[Encrypter],
) -> Dict[str, Union[float, np.ndarray]]:
"""Take actions to enforce fairness.
This method is designed to be called after an initial query
has been received and responded to, resulting in computing
and sharing fairness(-related) metrics.
Parameters
----------
netwk:
NetworkClient endpoint instance, connected to a server.
values:
List of locally-computed evaluation metrics, already shared
with the server for their (secure-)aggregation.
secagg:
Optional SecAgg encryption controller.
Returns
-------
metrics:
Computed local fairness(-related) metrics computed as part
of this routine, as a dict mapping scalar or numpy array
values with their name.
"""
......@@ -15,290 +15,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Draft API for Fairness-aware Federated Learning."""
"""Server-side ABC for fairness-aware federated learning controllers."""
import abc
import dataclasses
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
import numpy as np
from declearn.aggregator import Aggregator
from declearn.communication.api import NetworkClient, NetworkServer
from declearn.communication.utils import (
verify_client_messages_validity,
verify_server_message_validity,
)
from declearn.fairness.api._messages import (
from declearn.communication.api import NetworkServer
from declearn.communication.utils import verify_client_messages_validity
from declearn.messaging import (
FairnessCounts,
FairnessGroups,
SecaggFairnessCounts,
FairnessSetupQuery,
SerializedMessage,
)
from declearn.fairness.core import FairnessAccuracyComputer, FairnessDataset
from declearn.messaging import Error, FairnessQuery, FairnessReply, Message
from declearn.secagg.api import Decrypter, Encrypter
from declearn.secagg.api import Decrypter
from declearn.secagg.messaging import (
aggregate_secagg_messages,
SecaggFairnessReply,
SecaggFairnessCounts,
)
from declearn.training import TrainingManager
from declearn.utils import create_types_registry, register_type
__all__ = [
"FairnessControllerClient",
"FairnessControllerServer",
"FairnessSetupQuery",
]
class FairnessControllerClient(metaclass=abc.ABCMeta):
"""Abstract base class for client-side fairness controllers."""
def __init__(
self,
) -> None:
"""Instantiate the client-side fairness controller."""
self.groups = [] # type: List[Tuple[Any, ...]]
async def setup_fairness(
self,
netwk: NetworkClient,
manager: TrainingManager,
secagg: Optional[Encrypter],
params: Dict[str, Any],
) -> TrainingManager:
"""Participate in a routine to initialize fairness-aware learning.
This routine has the following structure:
- Exchange with the server to agree on an ordered list of sensitive
groups defined by the interesection of 1+ sensitive attributes
and (opt.) a classification target label.
- Send (encrypted) group-wise training sample counts, that the server
is to (secure-)aggregate.
- Perform any additional actions specific to the algorithm in use.
- On the client side, optionally alter the `TrainingManager` used.
- On the server side, optionally alter the `Aggregator` used.
Parameters
----------
netwk:
NetworkClient endpoint, registered to a server.
manager:
TrainingManager instance that was set up notwithstanding fairness.
secagg:
Optional SecAgg encryption controller.
params:
Dict of algorithm-specific keyword arguments received from
the server as part of the query that triggered this routine.
Warns
-----
RuntimeWarning
If the returned training manager differs from the input one.
Returns
-------
manager:
`TrainingManager` instance to use in the FL process, that may
or may not have been altered compared with the input one.
"""
# Verify that a training 'FairnessDataset' is available.
if not isinstance(manager.train_data, FairnessDataset):
msg = "Cannot set up fairness without a 'FairnessDataset'."
await netwk.send_message(Error(msg))
raise TypeError(msg)
# Gather local sensitive groups and their sample counts.
counts = manager.train_data.get_sensitive_group_counts()
groups = list(counts)
# Share them and receive a unified, ordered list of groups.
await netwk.send_message(FairnessGroups(groups=groups))
received = await netwk.recv_message()
message = await verify_server_message_validity(
netwk, received, expected=FairnessGroups
)
self.groups = message.groups
# Sort and fill out sample counts, opt. encrypt them and send them.
reply = FairnessCounts([counts.get(group, 0) for group in self.groups])
if secagg is None:
await netwk.send_message(reply)
else:
await netwk.send_message(
SecaggFairnessCounts.from_cleartext_message(reply, secagg)
)
# Run additional algorithm-specific setup steps.
return await self.finalize_fairness_setup(
netwk, manager, secagg, params
)
@abc.abstractmethod
async def finalize_fairness_setup(
self,
netwk: NetworkClient,
manager: TrainingManager,
secagg: Optional[Encrypter],
params: Dict[str, Any],
) -> TrainingManager:
"""Finalize the fairness setup routine and return an Aggregator.
This method is called as part of `setup_fairness`, and should
be defined by concrete subclasses to implement setup behavior
once the initial query/reply messages have been exchanged.
The returned `TrainingManager` may either be the input `manager`
or a new or modified version of it, depending on the needs of
the fairness-aware federated learning process being implemented.
Parameters
----------
netwk:
NetworkClient endpoint, registered to a server.
manager:
TrainingManager instance that was set up notwithstanding fairness.
secagg:
Optional SecAgg encryption controller.
params:
Dict of algorithm-specific keyword arguments received from
the server as part of the query that triggered this routine.
Warns
-----
RuntimeWarning
If the returned training manager differs from the input one.
Returns
-------
manager:
`TrainingManager` instance to use in the FL process, that may
or may not have been altered compared with the input one.
"""
async def fairness_round(
self,
netwk: NetworkClient,
query: FairnessQuery,
manager: TrainingManager,
secagg: Optional[Encrypter],
) -> None:
"""Participate in a round of actions to enforce fairness.
Parameters
----------
netwk:
NetworkClient endpoint instance, connected to a server.
query:
`FairnessQuery` message to participate in a fairness round.
manager:
TrainingManager instance holding the local model, optimizer, etc.
This method may (and usually does) have side effects on this.
secagg:
Optional SecAgg encryption controller.
"""
values = self.compute_fairness_measures(query, manager)
reply = FairnessReply(values=values)
if secagg is None:
await netwk.send_message(reply)
else:
await netwk.send_message(
SecaggFairnessReply.from_cleartext_message(reply, secagg)
)
await self.finalize_fairness_round(netwk, values, manager, secagg)
def compute_fairness_measures(
self,
query: FairnessQuery,
manager: TrainingManager,
) -> List[float]:
"""Compute fairness measures based on a received query.
By default, compute and return group-wise accuracy metrics,
weighted by group-wise sample counts. This may be modified
by algorithm-specific subclasses depending on algorithms'
needs.
Parameters
----------
query:
`FairnessQuery` message with computational effort constraints,
and optionally model weights to assign before evaluation.
manager:
TrainingManager instance holding the model to evaluate and the
training dataset on which to do so.
Returns
-------
values:
Computed values, as a deterministic-length ordered list
of float values.
"""
assert isinstance(manager.train_data, FairnessDataset)
if query.weights is not None:
manager.model.set_weights(query.weights, trainable=True)
# Compute group-wise accuracy metrics.
computer = FairnessAccuracyComputer(manager.train_data)
accuracy = computer.compute_groupwise_accuracy(
model=manager.model,
batch_size=query.batch_size,
n_batch=query.n_batch,
thresh=query.thresh,
)
# Scale computed accuracy metrics by sample counts.
accuracy = {
key: val * computer.counts[key] for key, val in accuracy.items()
}
# Gather ordered values (filling-in groups without samples).
return [accuracy.get(group, 0.0) for group in self.groups]
@abc.abstractmethod
async def finalize_fairness_round(
self,
netwk: NetworkClient,
values: List[float],
manager: TrainingManager,
secagg: Optional[Encrypter],
) -> None:
"""Take actions to enforce fairness.
This method is designed to be called after an initial query
has been received and responded to, resulting in computing
and sharing fairness(-related) metrics.
Parameters
----------
netwk:
NetworkClient endpoint instance, connected to a server.
values:
List of locally-computed evaluation metrics, already shared
with the server for their (secure-)aggregation.
manager:
TrainingManager instance holding the local model, optimizer, etc.
This method may (and usually does) have side effects on this.
secagg:
Optional SecAgg encryption controller.
"""
@create_types_registry(name="FairnessControllerServer")
class FairnessControllerServer(metaclass=abc.ABCMeta):
"""Abstract base class for server-side fairness controllers."""
@dataclasses.dataclass
class FairnessSetupQuery(Message, register=False, metaclass=abc.ABCMeta):
"""ABC message for all Fairness setup init requests.
algorithm: ClassVar[str]
"""Name of the fairness-enforcing algorithm.
This message should be subclassed into algorithm-specific messages.
This name should be unique across 'FairnessControllerServer' classes,
and shared with a unique paired 'FairnessControllerClient'. It is used
for type-registration and to enable instructing clients to instantiate
a controller matching that chosen by the server in a federated setting.
"""
@abc.abstractmethod
def instantiate_controller(
self,
) -> FairnessControllerClient:
"""Instantiate a `FairnessControllerClient` matching this query."""
def get_setup_params(
self,
) -> Dict[str, Any]:
"""Return a dict of parameters to pass to the client setup routine."""
return {}
class FairnessControllerServer(metaclass=abc.ABCMeta):
"""Abstract base class for server-side fairness controllers."""
def __init_subclass__(
cls,
register: bool = True,
) -> None:
"""Automatically type-register subclasses."""
if register:
register_type(cls, cls.algorithm, group="FairnessControllerServer")
def __init__(
self,
......@@ -328,8 +93,8 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
This routine has the following structure:
- Send a setup query to clients, the type of which depends
on the actual fairness-enforcing algorithm used.
- Send a setup query to clients, resulting in the instantiation
of client-side controllers matching this one.
- Exchange with clients to agree on an ordered list of sensitive
groups defined by the interesection of 1+ sensitive attributes
and (opt.) a classification target label.
......@@ -363,39 +128,30 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
query = self.prepare_fairness_setup_query()
await netwk.broadcast_message(query)
# Receive, aggregate, assign and send back sensitive group definitions.
await self._exchange_sensitive_groups_list(netwk)
# Wait for group-wise sample counts from clients.
received = await netwk.wait_for_messages()
# When SecAgg is not used, expect cleartext group-wise counts.
if secagg is None:
replies = await verify_client_messages_validity(
netwk, received, expected=FairnessCounts
)
counts = self._aggregate_cleartext_counts(replies)
# When SecAgg is used, expect and secure-aggregate encrypted counts.
else:
sec_rep = await verify_client_messages_validity(
netwk, received, expected=SecaggFairnessCounts
)
counts = aggregate_secagg_messages(sec_rep, secagg).counts
self.groups = await self._exchange_sensitive_groups_list(netwk)
# Receive, (secure-)aggregate and return group-wise sample counts.
counts = await self._aggregate_sensitive_groups_counts(netwk, secagg)
# Run additional algorithm-specific setup steps.
return await self.finalize_fairness_setup(netwk, counts, aggregator)
def _aggregate_cleartext_counts(
def prepare_fairness_setup_query(
self,
messages: Dict[str, FairnessCounts],
) -> List[int]:
"""Sum group-wise sample counts received from clients."""
counts = np.zeros(len(self.groups), dtype="uint64")
for message in messages.values():
counts += np.asarray(message.counts, dtype="uint64")
return counts.tolist()
) -> FairnessSetupQuery:
"""Return a request to setup fairness, broadcastable to clients.
Returns
-------
message:
`FairnessSetupQuery` instance to be sent to clients in order
to trigger the Fairness setup protocol.
"""
return FairnessSetupQuery(algorithm=self.algorithm)
@staticmethod
async def _exchange_sensitive_groups_list(
self,
netwk: NetworkServer,
) -> None:
"""Receive, aggregate, assign and share sensitive group definitions."""
) -> List[Tuple[Any, ...]]:
"""Receive, aggregate, share and return sensitive group definitions."""
received = await netwk.wait_for_messages()
# Verify and deserialize client-wise sensitive group definitions.
messages = await verify_client_messages_validity(
......@@ -403,22 +159,53 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
)
# Gather the sorted union of all existing definitions.
unique = {group for msg in messages.values() for group in msg.groups}
self.groups = sorted(list(unique))
groups = sorted(list(unique))
# Send it to clients, and expect their reply (encrypted counts).
await netwk.broadcast_message(FairnessGroups(groups=self.groups))
await netwk.broadcast_message(FairnessGroups(groups=groups))
return groups
@abc.abstractmethod
def prepare_fairness_setup_query(
async def _aggregate_sensitive_groups_counts(
self,
) -> FairnessSetupQuery:
"""Return a request to setup fairness, broadcastable to clients.
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> List[int]:
"""Receive, (secure-)aggregate and return group-wise sample counts."""
received = await netwk.wait_for_messages()
if secagg is None:
return await self._aggregate_sensitive_groups_counts_cleartext(
netwk=netwk, received=received, n_groups=len(self.groups)
)
return await self._aggregate_sensitive_groups_counts_encrypted(
netwk=netwk, received=received, decrypter=secagg
)
Returns
-------
message:
`FairnessSetupQuery` subclass instance to be sent to clients
in order to trigger the Fairness setup protocol.
"""
@staticmethod
async def _aggregate_sensitive_groups_counts_cleartext(
netwk: NetworkServer,
received: Dict[str, SerializedMessage],
n_groups: int,
) -> List[int]:
"""Deserialize and aggregate cleartext group-wise counts."""
replies = await verify_client_messages_validity(
netwk, received, expected=FairnessCounts
)
counts = np.zeros(n_groups, dtype="uint64")
for message in replies.values():
counts = counts + np.asarray(message.counts, dtype="uint64")
return counts.tolist()
@staticmethod
async def _aggregate_sensitive_groups_counts_encrypted(
netwk: NetworkServer,
received: Dict[str, SerializedMessage],
decrypter: Decrypter,
) -> List[int]:
"""Deserialize and secure-aggregate encrypted group-wise counts."""
replies = await verify_client_messages_validity(
netwk, received, expected=SecaggFairnessCounts
)
aggregated = aggregate_secagg_messages(replies, decrypter)
return aggregated.counts
@abc.abstractmethod
async def finalize_fairness_setup(
......@@ -456,7 +243,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
values: List[float],
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> None:
) -> Dict[str, Union[float, np.ndarray]]:
"""Orchestrate a round of actions to enforce fairness.
This method is designed to be called after an initial query
......@@ -475,4 +262,11 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
NetworkServer endpoint instance, to which clients are registered.
secagg:
Optional SecAgg decryption controller.
Returns
-------
metrics:
Computed local fairness(-related) metrics computed as part
of this routine, as a dict mapping scalar or numpy array
values with their name.
"""
......@@ -33,16 +33,13 @@ from declearn.communication.utils import (
verify_server_message_validity,
)
from declearn.dataset import Dataset, load_dataset_from_json
from declearn.fairness.api import (
FairnessControllerClient,
FairnessSetupQuery,
)
from declearn.fairness.api import FairnessControllerClient
from declearn.main.utils import Checkpointer
from declearn.messaging import Message, SerializedMessage
from declearn.training import TrainingManager
from declearn.secagg import parse_secagg_config_client
from declearn.secagg.api import Encrypter, SecaggConfigClient, SecaggSetupQuery
from declearn.secagg.messaging import SecaggEvaluationReply, SecaggTrainReply
from declearn.secagg import messaging as secagg_messaging
from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger
......@@ -457,27 +454,18 @@ class FederatedClient:
and should never be called in another context.
"""
assert self.trainmanager is not None
# Parse the serialized FairnessSetupQuery.
try:
received = await self.netwk.recv_message()
query = await verify_server_message_validity(
netwk=self.netwk,
received=received,
expected=FairnessSetupQuery, # type: ignore[type-abstract]
)
except Exception as exc:
error = "Failed to parse fairness setup query."
self.logger.critical(error)
await self.netwk.send_message(messaging.Error(error))
raise RuntimeError(error) from exc
# Await and deserialize a FairnessSetupQuery.
received = await self.netwk.recv_message()
query = await verify_server_message_validity(
self.netwk, received, expected=messaging.FairnessSetupQuery
)
# Instantiate a FairnessControllerClient and run its setup routine.
try:
self.fairness = query.instantiate_controller()
self.trainmanager = await self.fairness.setup_fairness(
netwk=self.netwk,
manager=self.trainmanager,
secagg=self._encrypter,
params=query.get_setup_params(),
self.fairness = FairnessControllerClient.from_setup_query(
query=query, manager=self.trainmanager
)
await self.fairness.setup_fairness(
netwk=self.netwk, secagg=self._encrypter
)
except Exception as exc:
error = (
......@@ -562,7 +550,7 @@ class FederatedClient:
if self._encrypter is not None and isinstance(
reply, messaging.TrainReply
):
reply = SecaggTrainReply.from_cleartext_message(
reply = secagg_messaging.SecaggTrainReply.from_cleartext_message(
cleartext=reply, encrypter=self._encrypter
)
# Send training results (or error message) to the server.
......@@ -613,7 +601,8 @@ class FederatedClient:
reply.metrics.clear()
# Optionally SecAgg-encrypt results.
if self._encrypter is not None:
reply = SecaggEvaluationReply.from_cleartext_message(
msg_cls = secagg_messaging.SecaggEvaluationReply
reply = msg_cls.from_cleartext_message(
cleartext=reply, encrypter=self._encrypter
)
# Send evaluation results (or error message) to the server.
......@@ -651,12 +640,17 @@ class FederatedClient:
await self.netwk.send_message(messaging.Error(error))
raise RuntimeError(error)
# Otherwise, run the controller's routine.
await self.fairness.fairness_round(
netwk=self.netwk,
query=query,
manager=self.trainmanager,
secagg=self._encrypter,
metrics = await self.fairness.fairness_round(
netwk=self.netwk, query=query, secagg=self._encrypter
)
# Optionally save computed fairness metrics.
if self.ckptr is not None:
self.ckptr.save_metrics(
metrics=metrics,
prefix="fairness_metrics",
append=(query.round_i > 0),
timestamp=f"round_{query.round_i}",
)
async def stop_training(
self,
......
......@@ -579,12 +579,20 @@ class FederatedServer:
)
values = self._aggregate_secagg_replies(secagg_replies).values
# Have the fairness controller process results.
await self.fairness.finalize_fairness_round(
metrics = await self.fairness.finalize_fairness_round(
round_i=round_i,
values=values,
netwk=self.netwk,
secagg=self._decrypter,
)
# Optionally save computed fairness metrics.
if self.ckptr is not None:
self.ckptr.save_metrics(
metrics=metrics,
prefix="fairness_metrics",
append=(query.round_i > 0),
timestamp=f"round_{query.round_i}",
)
async def training_round(
self,
......
......@@ -33,8 +33,6 @@ Base messages
* [Error][declearn.messaging.Error]
* [EvaluationReply][declearn.messaging.EvaluationReply]
* [EvaluationRequest][declearn.messaging.EvaluationRequest]
* [FairnessQuery][declearn.messaging.FairnessQuery]
* [FairnessReply][declearn.messaging.FairnessReply]
* [GenericMessage][declearn.messaging.GenericMessage]
* [InitRequest][declearn.messaging.InitRequest]
* [InitReply][declearn.messaging.InitReply]
......@@ -46,6 +44,14 @@ Base messages
* [TrainReply][declearn.messaging.TrainReply]
* [TrainRequest][declearn.messaging.TrainRequest]
Fairness algorithms messages
----------------------------
* [FairnessCounts][declearn.messaging.FairnessCounts]
* [FairnessGroups][declearn.messaging.FairnessGroups]
* [FairnessQuery][declearn.messaging.FairnessQuery]
* [FairnessReply][declearn.messaging.FairnessReply]
* [FairnessSetupQuery][declearn.messaging.FairnessSetupQuery]
"""
from ._api import (
......@@ -57,8 +63,6 @@ from ._base import (
Error,
EvaluationReply,
EvaluationRequest,
FairnessQuery,
FairnessReply,
GenericMessage,
InitRequest,
InitReply,
......@@ -70,3 +74,10 @@ from ._base import (
TrainReply,
TrainRequest,
)
from ._fairness import (
FairnessCounts,
FairnessGroups,
FairnessQuery,
FairnessReply,
FairnessSetupQuery,
)
......@@ -36,8 +36,6 @@ __all__ = [
"Error",
"EvaluationReply",
"EvaluationRequest",
"FairnessQuery",
"FairnessReply",
"GenericMessage",
"InitRequest",
"InitReply",
......@@ -102,44 +100,6 @@ class EvaluationReply(Message):
return kwargs
@dataclasses.dataclass
class FairnessQuery(Message):
"""Base Message for server-emitted fairness-computation queries.
This message conveys hyper-parameters used when evaluating a model's
accuracy and/or loss over group-wise samples (from which fairness is
derived). Model weights may be attached.
Algorithm-specific information should be conveyed using ad-hoc
messages exchanged as part of fairness-enforcement routines.
"""
typekey = "fairness-request"
round_i: int
batch_size: int = 32
n_batch: Optional[int] = None
thresh: Optional[float] = None
weights: Optional[Vector] = None
@dataclasses.dataclass
class FairnessReply(Message):
"""Base Message for client-emitted fairness-computation results.
This message conveys results from the evaluation of a model's accuracy
and/or loss over group-wise samples (from which fairness is derived).
This information is generically stored as a list of `values`, the
mearning and structure of which is left up to algorithm-specific
controllers.
"""
typekey = "fairness-reply"
values: List[float] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class GenericMessage(Message):
"""Generic message format, with action/params pair."""
......
......@@ -15,78 +15,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""API messages for fairness-aware federated learning setup and rounds."""
"""Messages for fairness-aware federated learning setup and rounds."""
import dataclasses
from typing import Any, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Self # future: import from typing (py >=3.11)
from declearn.messaging import Message
from declearn.secagg.api import Decrypter, Encrypter
from declearn.secagg.messaging import SecaggMessage
from declearn.model.api import Vector
__all__ = [
"FairnessAccuracy",
"FairnessCounts",
"FairnessGroups",
"SecaggFairnessAccuracy",
"SecaggFairnessCounts",
"FairnessQuery",
"FairnessReply",
"FairnessSetupQuery",
]
@dataclasses.dataclass
class FairnessAccuracy(Message):
"""Message for client-emitted model accuracy across sensitive groups.
Fields
------
values:
List of group-wise accuracy values, ordered based
on an agreed-upon sorted list of sensitive groups.
"""
values: List[float]
typekey = "fairness-accuracy"
@dataclasses.dataclass
class SecaggFairnessAccuracy(SecaggMessage[FairnessAccuracy]):
"""SecAgg counterpart of the 'FairnessAccuracy' message class."""
values: List[int]
typekey = "secagg-fairness-accuracy"
@classmethod
def from_cleartext_message(
cls,
cleartext: FairnessAccuracy,
encrypter: Encrypter,
) -> Self:
values = [encrypter.encrypt_float(val) for val in cleartext.values]
return cls(values=values)
def decrypt_wrapped_message(
self,
decrypter: Decrypter,
) -> FairnessAccuracy:
values = [decrypter.decrypt_float(val) for val in self.values]
return FairnessAccuracy(values=values)
def aggregate(
self,
other: Self,
decrypter: Decrypter,
) -> Self:
values = [
decrypter.sum_encrypted([v_a, v_b])
for v_a, v_b in zip(self.values, other.values)
]
return self.__class__(values=values)
@dataclasses.dataclass
class FairnessCounts(Message):
"""Message for client-emitted sample counts across sensitive groups.
......@@ -103,42 +50,6 @@ class FairnessCounts(Message):
typekey = "fairness-counts"
@dataclasses.dataclass
class SecaggFairnessCounts(SecaggMessage[FairnessCounts]):
"""SecAgg counterpart of the 'FairnessCounts' message class."""
counts: List[int]
typekey = "secagg-fairness-counts"
@classmethod
def from_cleartext_message(
cls,
cleartext: FairnessCounts,
encrypter: Encrypter,
) -> Self:
counts = [encrypter.encrypt_uint(val) for val in cleartext.counts]
return cls(counts=counts)
def decrypt_wrapped_message(
self,
decrypter: Decrypter,
) -> FairnessCounts:
counts = [decrypter.decrypt_uint(val) for val in self.counts]
return FairnessCounts(counts=counts)
def aggregate(
self,
other: Self,
decrypter: Decrypter,
) -> Self:
counts = [
decrypter.sum_encrypted([v_a, v_b])
for v_a, v_b in zip(self.counts, other.counts)
]
return self.__class__(counts=counts)
@dataclasses.dataclass
class FairnessGroups(Message):
"""Message to exchange a list of unique sensitive group definitions.
......@@ -166,3 +77,60 @@ class FairnessGroups(Message):
) -> Self:
kwargs["groups"] = [tuple(group) for group in kwargs["groups"]]
return super().from_kwargs(**kwargs)
@dataclasses.dataclass
class FairnessQuery(Message):
"""Base Message for server-emitted fairness-computation queries.
This message conveys hyper-parameters used when evaluating a model's
accuracy and/or loss over group-wise samples (from which fairness is
derived). Model weights may be attached.
Algorithm-specific information should be conveyed using ad-hoc
messages exchanged as part of fairness-enforcement routines.
"""
typekey = "fairness-request"
round_i: int
batch_size: int = 32
n_batch: Optional[int] = None
thresh: Optional[float] = None
weights: Optional[Vector] = None
@dataclasses.dataclass
class FairnessReply(Message):
"""Base Message for client-emitted fairness-computation results.
This message conveys results from the evaluation of a model's accuracy
and/or loss over group-wise samples (from which fairness is derived).
This information is generically stored as a list of `values`, the
mearning and structure of which is left up to algorithm-specific
controllers.
"""
typekey = "fairness-reply"
values: List[float] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class FairnessSetupQuery(Message):
"""Message to instruct clients to instantiate a fairness controller.
Fields
------
algorithm:
Name of the algorithm, under which the target controller type
is expected to be registered.
params:
Dict of instantiation keyword arguments to the controller.
"""
typekey = "fairness-setup-query"
algorithm: str
params: Dict[str, Any] = dataclasses.field(default_factory=dict)
......@@ -26,6 +26,7 @@ from typing_extensions import Self # future: import from typing (py >=3.11)
from declearn.aggregator import ModelUpdates
from declearn.messaging import (
EvaluationReply,
FairnessCounts,
FairnessReply,
Message,
TrainReply,
......@@ -36,6 +37,7 @@ from declearn.secagg.api import Decrypter, Encrypter, SecureAggregate
__all__ = [
"SecaggEvaluationReply",
"SecaggFairnessCounts",
"SecaggFairnessReply",
"SecaggMessage",
"SecaggTrainReply",
......@@ -264,6 +266,42 @@ class SecaggEvaluationReply(SecaggMessage[EvaluationReply]):
)
@dataclasses.dataclass
class SecaggFairnessCounts(SecaggMessage[FairnessCounts]):
"""SecAgg counterpart of the 'FairnessCounts' message class."""
counts: List[int]
typekey = "secagg-fairness-counts"
@classmethod
def from_cleartext_message(
cls,
cleartext: FairnessCounts,
encrypter: Encrypter,
) -> Self:
counts = [encrypter.encrypt_uint(val) for val in cleartext.counts]
return cls(counts=counts)
def decrypt_wrapped_message(
self,
decrypter: Decrypter,
) -> FairnessCounts:
counts = [decrypter.decrypt_uint(val) for val in self.counts]
return FairnessCounts(counts=counts)
def aggregate(
self,
other: Self,
decrypter: Decrypter,
) -> Self:
counts = [
decrypter.sum_encrypted([v_a, v_b])
for v_a, v_b in zip(self.counts, other.counts)
]
return self.__class__(counts=counts)
@dataclasses.dataclass
class SecaggFairnessReply(SecaggMessage[FairnessReply]):
"""SecAgg-wrapped 'FairnessReply' message."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment