Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 4cd7d94d authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Refactor some fairness controllers code.

- Expose some subroutines under setup and fairness round,
  for the mere sake of making tests easier to perform, as
  well as to enable variants over current algorithms in
  the future / in experiments.
- Rename some methods and re-order some arguments.
- Refactor server-side aggregation of metrics, making it
  part of 'FairnessControllerServer' rather than part of
  'FederatedServer' backend code.
parent 21368711
No related branches found
No related tags found
1 merge request!69Enable Fairness-Aware Federated Learning
...@@ -161,12 +161,36 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ...@@ -161,12 +161,36 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
secagg: secagg:
Optional SecAgg encryption controller. Optional SecAgg encryption controller.
""" """
# Agree on a list of sensitive groups and share local sample counts.
await self.exchange_sensitive_groups_list_and_counts(netwk, secagg)
# Run additional algorithm-specific setup steps.
await self.finalize_fairness_setup(netwk, secagg)
async def exchange_sensitive_groups_list_and_counts(
self,
netwk: NetworkClient,
secagg: Optional[Encrypter],
) -> None:
"""Agree on a list of sensitive groups and share local sample counts.
This method performs the following routine:
- Send the list of local sensitive group definitions to the server.
- Await a unified list of sensitive groups in return.
- Assign the received list as `groups` attribute.
- Send (optionally-encrypted) group-wise sample counts to the server.
Parameters
----------
netwk:
`NetworkClient` endpoint, connected to a server.
secagg:
Optional SecAgg encryption controller.
"""
# Share sensitive groups definitions and received an ordered list. # Share sensitive groups definitions and received an ordered list.
self.groups = await self._exchange_sensitive_groups_list(netwk) self.groups = await self._exchange_sensitive_groups_list(netwk)
# Send group-wise sample counts for the server to (secure-)aggregate. # Send group-wise sample counts for the server to (secure-)aggregate.
await self._send_sensitive_groups_counts(netwk, secagg) await self._send_sensitive_groups_counts(netwk, secagg)
# Run additional algorithm-specific setup steps.
await self.finalize_fairness_setup(netwk, secagg)
async def _exchange_sensitive_groups_list( async def _exchange_sensitive_groups_list(
self, self,
...@@ -220,7 +244,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ...@@ -220,7 +244,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
Optional SecAgg encryption controller. Optional SecAgg encryption controller.
""" """
async def fairness_round( async def run_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
query: FairnessQuery, query: FairnessQuery,
...@@ -253,7 +277,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ...@@ -253,7 +277,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
await netwk.send_message(Error(error)) await netwk.send_message(Error(error))
raise RuntimeError(error) from exc raise RuntimeError(error) from exc
# Run additional algorithm-specific steps. # Run additional algorithm-specific steps.
return await self.finalize_fairness_round(netwk, values, secagg) return await self.finalize_fairness_round(netwk, secagg, values)
async def _compute_and_share_fairness_measures( async def _compute_and_share_fairness_measures(
self, self,
...@@ -374,8 +398,8 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ...@@ -374,8 +398,8 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
values: Dict[str, Dict[Tuple[Any, ...], float]],
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
values: Dict[str, Dict[Tuple[Any, ...], float]],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
"""Take actions to enforce fairness. """Take actions to enforce fairness.
...@@ -387,13 +411,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta): ...@@ -387,13 +411,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
---------- ----------
netwk: netwk:
NetworkClient endpoint instance, connected to a server. NetworkClient endpoint instance, connected to a server.
secagg:
Optional SecAgg encryption controller.
values: values:
Nested dict of locally-computed group-wise metrics. Nested dict of locally-computed group-wise metrics.
This is the second set of `compute_fairness_measures` return This is the second set of `compute_fairness_measures` return
values; when this method is called, the first has already values; when this method is called, the first has already
been shared with the server for (secure-)aggregation. been shared with the server for (secure-)aggregation.
secagg:
Optional SecAgg encryption controller.
Returns Returns
------- -------
......
...@@ -26,8 +26,10 @@ from declearn.aggregator import Aggregator ...@@ -26,8 +26,10 @@ from declearn.aggregator import Aggregator
from declearn.communication.api import NetworkServer from declearn.communication.api import NetworkServer
from declearn.communication.utils import verify_client_messages_validity from declearn.communication.utils import verify_client_messages_validity
from declearn.messaging import ( from declearn.messaging import (
Error,
FairnessCounts, FairnessCounts,
FairnessGroups, FairnessGroups,
FairnessReply,
FairnessSetupQuery, FairnessSetupQuery,
SerializedMessage, SerializedMessage,
) )
...@@ -35,6 +37,7 @@ from declearn.secagg.api import Decrypter ...@@ -35,6 +37,7 @@ from declearn.secagg.api import Decrypter
from declearn.secagg.messaging import ( from declearn.secagg.messaging import (
aggregate_secagg_messages, aggregate_secagg_messages,
SecaggFairnessCounts, SecaggFairnessCounts,
SecaggFairnessReply,
) )
from declearn.utils import create_types_registry, register_type from declearn.utils import create_types_registry, register_type
...@@ -67,7 +70,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -67,7 +70,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
def __init__( def __init__(
self, self,
f_type: str, f_type: str,
f_args: Optional[Dict[str, Any]], f_args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Instantiate the server-side fairness controller. """Instantiate the server-side fairness controller.
...@@ -82,6 +85,8 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -82,6 +85,8 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
self.f_args = f_args or {} self.f_args = f_args or {}
self.groups = [] # type: List[Tuple[Any, ...]] self.groups = [] # type: List[Tuple[Any, ...]]
# Fairness Setup methods.
async def setup_fairness( async def setup_fairness(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
...@@ -126,12 +131,14 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -126,12 +131,14 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
# Send a setup query to all clients. # Send a setup query to all clients.
query = self.prepare_fairness_setup_query() query = self.prepare_fairness_setup_query()
await netwk.broadcast_message(query) await netwk.broadcast_message(query)
# Receive, aggregate, assign and send back sensitive group definitions. # Agree on a list of sensitive groups and aggregate sample counts.
self.groups = await self._exchange_sensitive_groups_list(netwk) counts = await self.exchange_sensitive_groups_list_and_counts(
# Receive, (secure-)aggregate and return group-wise sample counts. netwk, secagg
counts = await self._aggregate_sensitive_groups_counts(netwk, secagg) )
# Run additional algorithm-specific setup steps. # Run additional algorithm-specific setup steps.
return await self.finalize_fairness_setup(netwk, counts, aggregator) return await self.finalize_fairness_setup(
netwk, secagg, counts, aggregator
)
def prepare_fairness_setup_query( def prepare_fairness_setup_query(
self, self,
...@@ -149,6 +156,40 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -149,6 +156,40 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
params={"f_type": self.f_type, "f_args": self.f_args}, params={"f_type": self.f_type, "f_args": self.f_args},
) )
async def exchange_sensitive_groups_list_and_counts(
self,
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> List[int]:
"""Agree on a list of sensitive groups and aggregate sample counts.
This method performs the following routine:
- Await `FairnessGroups` messages from clients with group definitions.
- Assign a sorted list of sensitive groups as `groups` attribute.
- Share that list with clients.
- Await possibly-encrypted group-wise sample counts from clients.
- (Secure-)Aggregate these sample counts and return them.
Parameters
----------
netwk:
`NetworkServer` endpoint, through which a fairness setup query
was previously sent to all clients.
secagg:
Optional SecAgg decryption controller.
Returns
-------
counts:
List of group-wise total sample count across clients,
sorted based on the newly-assigned `self.groups`.
"""
# Receive, aggregate, assign and send back sensitive group definitions.
self.groups = await self._exchange_sensitive_groups_list(netwk)
# Receive, (secure-)aggregate and return group-wise sample counts.
return await self._aggregate_sensitive_groups_counts(netwk, secagg)
@staticmethod @staticmethod
async def _exchange_sensitive_groups_list( async def _exchange_sensitive_groups_list(
netwk: NetworkServer, netwk: NetworkServer,
...@@ -213,6 +254,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -213,6 +254,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
async def finalize_fairness_setup( async def finalize_fairness_setup(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter],
counts: List[int], counts: List[int],
aggregator: Aggregator, aggregator: Aggregator,
) -> Aggregator: ) -> Aggregator:
...@@ -238,13 +280,90 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -238,13 +280,90 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
or may not have been altered compared with the input one. or may not have been altered compared with the input one.
""" """
# Fairness Round methods.
async def run_fairness_round(
self,
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> Dict[str, Union[float, np.ndarray]]:
"""Secure-aggregate and post-process fairness measures.
This method is to be run **after** having sent a `FairnessQuery`
to clients. It consists in receiving, (secure-)aggregating and
post-processing measures that clients produce as a reply to that
query. This may involve further algorithm-specific communications.
Parameters
----------
netwk:
NetworkServer endpoint instance, to which clients are registered.
secagg:
Optional SecAgg decryption controller.
Returns
-------
metrics:
Fairness(-related) metrics computed as part of this routine,
as a dict mapping scalar or numpy array values with their name.
"""
values = await self.receive_and_aggregate_fairness_measures(
netwk, secagg
)
return await self.finalize_fairness_round(netwk, secagg, values)
async def receive_and_aggregate_fairness_measures(
self,
netwk: NetworkServer,
secagg: Optional[Decrypter],
) -> List[float]:
"""Await and (secure-)aggregate client-wise fairness-related metrics.
This method is designed to be called after sending a `FairnessQuery`
to clients, and returns values that are yet to be parsed and used by
the algorithm-dependent `finalize_fairness_round` method.
Parameters
----------
netwk:
NetworkServer endpoint instance, to which clients are registered.
secagg:
Optional SecAgg decryption controller.
Returns
-------
metrics:
List of sum-aggregated fairness-related metrics (as floats).
By default, these are group-wise accuracy values; this may
however be changed or expanded by algorithm-specific classes.
"""
received = await netwk.wait_for_messages()
# Case when expecting cleartext values.
if secagg is None:
replies = await verify_client_messages_validity(
netwk, received, expected=FairnessReply
)
if len(set(len(r.values) for r in replies.values())) != 1:
error = "Clients sent fairness values of different lengths."
await netwk.broadcast_message(Error(error))
raise RuntimeError(error)
return [
sum(rval)
for rval in zip(*[reply.values for reply in replies.values()])
]
# Case when expecting encrypted values.
secagg_replies = await verify_client_messages_validity(
netwk, received, expected=SecaggFairnessReply
)
agg_reply = aggregate_secagg_messages(secagg_replies, decrypter=secagg)
return agg_reply.values
@abc.abstractmethod @abc.abstractmethod
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
round_i: int,
values: List[float],
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter], secagg: Optional[Decrypter],
values: List[float],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
"""Orchestrate a round of actions to enforce fairness. """Orchestrate a round of actions to enforce fairness.
...@@ -254,21 +373,17 @@ class FairnessControllerServer(metaclass=abc.ABCMeta): ...@@ -254,21 +373,17 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
Parameters Parameters
---------- ----------
round_i:
Index of the current round (reflecting that of an upcoming
training round).
values:
Aggregated metrics resulting from the fairness evaluation
run by clients at this round.
netwk: netwk:
NetworkServer endpoint instance, to which clients are registered. NetworkServer endpoint instance, to which clients are registered.
secagg: secagg:
Optional SecAgg decryption controller. Optional SecAgg decryption controller.
values:
Aggregated metrics resulting from the fairness evaluation
run by clients at this round.
Returns Returns
------- -------
metrics: metrics:
Computed local fairness(-related) metrics computed as part Fairness(-related) metrics computed as part of this routine,
of this routine, as a dict mapping scalar or numpy array as a dict mapping scalar or numpy array values with their name.
values with their name.
""" """
...@@ -84,7 +84,7 @@ class FairbatchControllerClient(FairnessControllerClient): ...@@ -84,7 +84,7 @@ class FairbatchControllerClient(FairnessControllerClient):
If the sampling pobabilities' update fails. If the sampling pobabilities' update fails.
""" """
# Receive aggregated sensitive weights. # Receive aggregated sensitive weights.
received = await netwk.check_message() received = await netwk.recv_message()
message = await verify_server_message_validity( message = await verify_server_message_validity(
netwk, received, expected=FairbatchSamplingProbas netwk, received, expected=FairbatchSamplingProbas
) )
...@@ -114,15 +114,15 @@ class FairbatchControllerClient(FairnessControllerClient): ...@@ -114,15 +114,15 @@ class FairbatchControllerClient(FairnessControllerClient):
thresh: Optional[float] = None, thresh: Optional[float] = None,
) -> List[MeanMetric]: ) -> List[MeanMetric]:
loss = self.computer.setup_loss_metric(model=self.manager.model) loss = self.computer.setup_loss_metric(model=self.manager.model)
metrics = super().setup_fairness_metrics() metrics = super().setup_fairness_metrics(thresh=thresh)
metrics.append(loss) metrics.append(loss)
return metrics return metrics
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
values: Dict[str, Dict[Tuple[Any, ...], float]],
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
values: Dict[str, Dict[Tuple[Any, ...], float]],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Await updated loss weights from the server. # Await updated loss weights from the server.
await self._update_fairbatch_sampling_probas(netwk) await self._update_fairbatch_sampling_probas(netwk)
......
...@@ -107,6 +107,7 @@ class FairbatchControllerServer(FairnessControllerServer): ...@@ -107,6 +107,7 @@ class FairbatchControllerServer(FairnessControllerServer):
async def finalize_fairness_setup( async def finalize_fairness_setup(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter],
counts: List[int], counts: List[int],
aggregator: Aggregator, aggregator: Aggregator,
) -> Aggregator: ) -> Aggregator:
...@@ -150,10 +151,9 @@ class FairbatchControllerServer(FairnessControllerServer): ...@@ -150,10 +151,9 @@ class FairbatchControllerServer(FairnessControllerServer):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
round_i: int,
values: List[float],
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter], secagg: Optional[Decrypter],
values: List[float],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Unpack group-wise accuracy and loss values. # Unpack group-wise accuracy and loss values.
accuracy = dict(zip(self.groups, values[: len(self.groups)])) accuracy = dict(zip(self.groups, values[: len(self.groups)]))
......
...@@ -106,8 +106,8 @@ class FairfedControllerClient(FairnessControllerClient): ...@@ -106,8 +106,8 @@ class FairfedControllerClient(FairnessControllerClient):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
values: Dict[str, Dict[Tuple[Any, ...], float]],
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
values: Dict[str, Dict[Tuple[Any, ...], float]],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Await absolute mean fairness across all clients. # Await absolute mean fairness across all clients.
received = await netwk.recv_message() received = await netwk.recv_message()
......
...@@ -113,6 +113,7 @@ class FairfedControllerServer(FairnessControllerServer): ...@@ -113,6 +113,7 @@ class FairfedControllerServer(FairnessControllerServer):
async def finalize_fairness_setup( async def finalize_fairness_setup(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter],
counts: List[int], counts: List[int],
aggregator: Aggregator, aggregator: Aggregator,
) -> Aggregator: ) -> Aggregator:
...@@ -130,10 +131,9 @@ class FairfedControllerServer(FairnessControllerServer): ...@@ -130,10 +131,9 @@ class FairfedControllerServer(FairnessControllerServer):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
round_i: int,
values: List[float],
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter], secagg: Optional[Decrypter],
values: List[float],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Unpack group-wise accuracy values and compute fairness ones. # Unpack group-wise accuracy values and compute fairness ones.
accuracy = dict(zip(self.groups, values)) accuracy = dict(zip(self.groups, values))
......
...@@ -69,7 +69,7 @@ class FairgradControllerClient(FairnessControllerClient): ...@@ -69,7 +69,7 @@ class FairgradControllerClient(FairnessControllerClient):
If the weights' update fails. If the weights' update fails.
""" """
# Receive aggregated sensitive weights. # Receive aggregated sensitive weights.
received = await netwk.check_message() received = await netwk.recv_message()
message = await verify_server_message_validity( message = await verify_server_message_validity(
netwk, received, expected=FairgradWeights netwk, received, expected=FairgradWeights
) )
...@@ -93,8 +93,8 @@ class FairgradControllerClient(FairnessControllerClient): ...@@ -93,8 +93,8 @@ class FairgradControllerClient(FairnessControllerClient):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
values: Dict[str, Dict[Tuple[Any, ...], float]],
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
values: Dict[str, Dict[Tuple[Any, ...], float]],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Await updated loss weights from the server. # Await updated loss weights from the server.
await self._update_fairgrad_weights(netwk) await self._update_fairgrad_weights(netwk)
......
...@@ -189,6 +189,7 @@ class FairgradControllerServer(FairnessControllerServer): ...@@ -189,6 +189,7 @@ class FairgradControllerServer(FairnessControllerServer):
async def finalize_fairness_setup( async def finalize_fairness_setup(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter],
counts: List[int], counts: List[int],
aggregator: Aggregator, aggregator: Aggregator,
) -> Aggregator: ) -> Aggregator:
...@@ -230,10 +231,9 @@ class FairgradControllerServer(FairnessControllerServer): ...@@ -230,10 +231,9 @@ class FairgradControllerServer(FairnessControllerServer):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
round_i: int,
values: List[float],
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter], secagg: Optional[Decrypter],
values: List[float],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Unpack group-wise accuracy metrics and update loss weights. # Unpack group-wise accuracy metrics and update loss weights.
accuracy = dict(zip(self.groups, values)) accuracy = dict(zip(self.groups, values))
......
...@@ -45,8 +45,8 @@ class FairnessMonitorClient(FairnessControllerClient): ...@@ -45,8 +45,8 @@ class FairnessMonitorClient(FairnessControllerClient):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
netwk: NetworkClient, netwk: NetworkClient,
values: Dict[str, Dict[Tuple[Any, ...], float]],
secagg: Optional[Encrypter], secagg: Optional[Encrypter],
values: Dict[str, Dict[Tuple[Any, ...], float]],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
return { return {
f"{metric}_{group}": value f"{metric}_{group}": value
......
...@@ -42,7 +42,7 @@ class FairnessMonitorServer(FairnessControllerServer): ...@@ -42,7 +42,7 @@ class FairnessMonitorServer(FairnessControllerServer):
def __init__( def __init__(
self, self,
f_type: str, f_type: str,
f_args: Optional[Dict[str, Any]], f_args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__(f_type, f_args) super().__init__(f_type, f_args)
# Assign a temporary fairness functions, replaced at setup time. # Assign a temporary fairness functions, replaced at setup time.
...@@ -53,6 +53,7 @@ class FairnessMonitorServer(FairnessControllerServer): ...@@ -53,6 +53,7 @@ class FairnessMonitorServer(FairnessControllerServer):
async def finalize_fairness_setup( async def finalize_fairness_setup(
self, self,
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter],
counts: List[int], counts: List[int],
aggregator: Aggregator, aggregator: Aggregator,
) -> Aggregator: ) -> Aggregator:
...@@ -65,10 +66,9 @@ class FairnessMonitorServer(FairnessControllerServer): ...@@ -65,10 +66,9 @@ class FairnessMonitorServer(FairnessControllerServer):
async def finalize_fairness_round( async def finalize_fairness_round(
self, self,
round_i: int,
values: List[float],
netwk: NetworkServer, netwk: NetworkServer,
secagg: Optional[Decrypter], secagg: Optional[Decrypter],
values: List[float],
) -> Dict[str, Union[float, np.ndarray]]: ) -> Dict[str, Union[float, np.ndarray]]:
# Unpack group-wise accuracy metrics and compute fairness ones. # Unpack group-wise accuracy metrics and compute fairness ones.
accuracy = dict(zip(self.groups, values)) accuracy = dict(zip(self.groups, values))
......
...@@ -660,7 +660,7 @@ class FederatedClient: ...@@ -660,7 +660,7 @@ class FederatedClient:
await self.netwk.send_message(messaging.Error(error)) await self.netwk.send_message(messaging.Error(error))
return return
# Otherwise, run the controller's routine. # Otherwise, run the controller's routine.
metrics = await self.fairness.fairness_round( metrics = await self.fairness.run_fairness_round(
netwk=self.netwk, query=query, secagg=self._encrypter netwk=self.netwk, query=query, secagg=self._encrypter
) )
# Optionally save computed fairness metrics. # Optionally save computed fairness metrics.
......
...@@ -565,27 +565,8 @@ class FederatedServer: ...@@ -565,27 +565,8 @@ class FederatedServer:
weights=None, weights=None,
) )
await self._send_request_with_optional_weights(query, clients) await self._send_request_with_optional_weights(query, clients)
# Await and (secure-)aggregate) results. # Await, (secure-)aggregate and process fairness measures.
self.logger.info("Awaiting clients' fairness measures.") metrics = await self.fairness.run_fairness_round(
if self._decrypter is None:
replies = await self._collect_results(
clients, messaging.FairnessReply, "fairness round"
)
if len(set(len(r.values) for r in replies.values())) != 1:
error = "Clients sent fairness values of different lengths."
self.logger.error(error)
await self.netwk.broadcast_message(messaging.Error(error))
raise RuntimeError(error)
values = [sum(c_values) for c_values in zip(*replies.values())]
else:
secagg_replies = await self._collect_results(
clients, secagg_messaging.SecaggFairnessReply, "fairness round"
)
values = self._aggregate_secagg_replies(secagg_replies).values
# Have the fairness controller process results.
metrics = await self.fairness.finalize_fairness_round(
round_i=round_i,
values=values,
netwk=self.netwk, netwk=self.netwk,
secagg=self._decrypter, secagg=self._decrypter,
) )
......
...@@ -1067,13 +1067,13 @@ class TestFederatedClientFairnessRound: ...@@ -1067,13 +1067,13 @@ class TestFederatedClientFairnessRound:
# Call the 'fairness_round' routine and verify expected actions. # Call the 'fairness_round' routine and verify expected actions.
request = messaging.FairnessQuery(round_i=1) request = messaging.FairnessQuery(round_i=1)
await client.fairness_round(request) await client.fairness_round(request)
fairness.fairness_round.assert_awaited_once_with( fairness.run_fairness_round.assert_awaited_once_with(
netwk=netwk, query=request, secagg=None netwk=netwk, query=request, secagg=None
) )
# Verify that when a checkpointer is set, it is used. # Verify that when a checkpointer is set, it is used.
if ckpt: if ckpt:
client.ckptr.save_metrics.assert_called_once_with( # type: ignore client.ckptr.save_metrics.assert_called_once_with( # type: ignore
metrics=fairness.fairness_round.return_value, metrics=fairness.run_fairness_round.return_value,
prefix="fairness_metrics", prefix="fairness_metrics",
append=True, append=True,
timestamp="round_1", timestamp="round_1",
...@@ -1102,7 +1102,7 @@ class TestFederatedClientFairnessRound: ...@@ -1102,7 +1102,7 @@ class TestFederatedClientFairnessRound:
# Call the 'fairness_round' routine and verify expected actions. # Call the 'fairness_round' routine and verify expected actions.
request = messaging.FairnessQuery(round_i=1) request = messaging.FairnessQuery(round_i=1)
await client.fairness_round(request) await client.fairness_round(request)
fairness.fairness_round.assert_awaited_once_with( fairness.run_fairness_round.assert_awaited_once_with(
netwk=netwk, netwk=netwk,
query=request, query=request,
secagg=secagg.setup_encrypter.return_value, secagg=secagg.setup_encrypter.return_value,
...@@ -1149,7 +1149,7 @@ class TestFederatedClientFairnessRound: ...@@ -1149,7 +1149,7 @@ class TestFederatedClientFairnessRound:
netwk.send_message.assert_called_once() netwk.send_message.assert_called_once()
reply = netwk.send_message.call_args.args[0] reply = netwk.send_message.call_args.args[0]
assert isinstance(reply, messaging.Error) assert isinstance(reply, messaging.Error)
fairness.fairness_round.assert_not_called() fairness.run_fairness_round.assert_not_called()
class TestFederatedClientMisc: class TestFederatedClientMisc:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment