From 4cd7d94d737bd9cb470a8fb7196f62d3a97ffccc Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 Jul 2024 15:18:49 +0200
Subject: [PATCH] Refactor some fairness controllers code.

- Expose some subroutines under setup and fairness round,
  for the mere sake of making tests easier to perform, as
  well as to enable variants over current algorithms in
  the future / in experiments.
- Rename some methods and re-order some arguments.
- Refactor server-side aggregation of metrics, making it
  part of 'FairnessControllerServer' rather than part of
  'FederatedServer' backend code.
---
 declearn/fairness/api/_client.py       |  38 +++++--
 declearn/fairness/api/_server.py       | 149 ++++++++++++++++++++++---
 declearn/fairness/fairbatch/_client.py |   6 +-
 declearn/fairness/fairbatch/_server.py |   4 +-
 declearn/fairness/fairfed/_client.py   |   2 +-
 declearn/fairness/fairfed/_server.py   |   4 +-
 declearn/fairness/fairgrad/_client.py  |   4 +-
 declearn/fairness/fairgrad/_server.py  |   4 +-
 declearn/fairness/monitor/_client.py   |   2 +-
 declearn/fairness/monitor/_server.py   |   6 +-
 declearn/main/_client.py               |   2 +-
 declearn/main/_server.py               |  23 +---
 test/main/test_main_client.py          |   8 +-
 13 files changed, 186 insertions(+), 66 deletions(-)

diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py
index 6f187cc..3d08709 100644
--- a/declearn/fairness/api/_client.py
+++ b/declearn/fairness/api/_client.py
@@ -161,12 +161,36 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         secagg:
             Optional SecAgg encryption controller.
         """
+        # Agree on a list of sensitive groups and share local sample counts.
+        await self.exchange_sensitive_groups_list_and_counts(netwk, secagg)
+        # Run additional algorithm-specific setup steps.
+        await self.finalize_fairness_setup(netwk, secagg)
+
+    async def exchange_sensitive_groups_list_and_counts(
+        self,
+        netwk: NetworkClient,
+        secagg: Optional[Encrypter],
+    ) -> None:
+        """Agree on a list of sensitive groups and share local sample counts.
+
+        This method performs the following routine:
+
+        - Send the list of local sensitive group definitions to the server.
+        - Await a unified list of sensitive groups in return.
+        - Assign the received list as `groups` attribute.
+        - Send (optionally-encrypted) group-wise sample counts to the server.
+
+        Parameters
+        ----------
+        netwk:
+            `NetworkClient` endpoint, connected to a server.
+        secagg:
+            Optional SecAgg encryption controller.
+        """
         # Share sensitive groups definitions and received an ordered list.
         self.groups = await self._exchange_sensitive_groups_list(netwk)
         # Send group-wise sample counts for the server to (secure-)aggregate.
         await self._send_sensitive_groups_counts(netwk, secagg)
-        # Run additional algorithm-specific setup steps.
-        await self.finalize_fairness_setup(netwk, secagg)
 
     async def _exchange_sensitive_groups_list(
         self,
@@ -220,7 +244,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
             Optional SecAgg encryption controller.
         """
 
-    async def fairness_round(
+    async def run_fairness_round(
         self,
         netwk: NetworkClient,
         query: FairnessQuery,
@@ -253,7 +277,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
             await netwk.send_message(Error(error))
             raise RuntimeError(error) from exc
         # Run additional algorithm-specific steps.
-        return await self.finalize_fairness_round(netwk, values, secagg)
+        return await self.finalize_fairness_round(netwk, secagg, values)
 
     async def _compute_and_share_fairness_measures(
         self,
@@ -374,8 +398,8 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
     async def finalize_fairness_round(
         self,
         netwk: NetworkClient,
-        values: Dict[str, Dict[Tuple[Any, ...], float]],
         secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
     ) -> Dict[str, Union[float, np.ndarray]]:
         """Take actions to enforce fairness.
 
@@ -387,13 +411,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         ----------
         netwk:
             NetworkClient endpoint instance, connected to a server.
+        secagg:
+            Optional SecAgg encryption controller.
         values:
             Nested dict of locally-computed group-wise metrics.
             This is the second set of `compute_fairness_measures` return
             values; when this method is called, the first has already
             been shared with the server for (secure-)aggregation.
-        secagg:
-            Optional SecAgg encryption controller.
 
         Returns
         -------
diff --git a/declearn/fairness/api/_server.py b/declearn/fairness/api/_server.py
index 33deb01..8ef8a54 100644
--- a/declearn/fairness/api/_server.py
+++ b/declearn/fairness/api/_server.py
@@ -26,8 +26,10 @@ from declearn.aggregator import Aggregator
 from declearn.communication.api import NetworkServer
 from declearn.communication.utils import verify_client_messages_validity
 from declearn.messaging import (
+    Error,
     FairnessCounts,
     FairnessGroups,
+    FairnessReply,
     FairnessSetupQuery,
     SerializedMessage,
 )
@@ -35,6 +37,7 @@ from declearn.secagg.api import Decrypter
 from declearn.secagg.messaging import (
     aggregate_secagg_messages,
     SecaggFairnessCounts,
+    SecaggFairnessReply,
 )
 from declearn.utils import create_types_registry, register_type
 
@@ -67,7 +70,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
     def __init__(
         self,
         f_type: str,
-        f_args: Optional[Dict[str, Any]],
+        f_args: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Instantiate the server-side fairness controller.
 
@@ -82,6 +85,8 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
         self.f_args = f_args or {}
         self.groups = []  # type: List[Tuple[Any, ...]]
 
+    # Fairness Setup methods.
+
     async def setup_fairness(
         self,
         netwk: NetworkServer,
@@ -126,12 +131,14 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
         # Send a setup query to all clients.
         query = self.prepare_fairness_setup_query()
         await netwk.broadcast_message(query)
-        # Receive, aggregate, assign and send back sensitive group definitions.
-        self.groups = await self._exchange_sensitive_groups_list(netwk)
-        # Receive, (secure-)aggregate and return group-wise sample counts.
-        counts = await self._aggregate_sensitive_groups_counts(netwk, secagg)
+        # Agree on a list of sensitive groups and aggregate sample counts.
+        counts = await self.exchange_sensitive_groups_list_and_counts(
+            netwk, secagg
+        )
         # Run additional algorithm-specific setup steps.
-        return await self.finalize_fairness_setup(netwk, counts, aggregator)
+        return await self.finalize_fairness_setup(
+            netwk, secagg, counts, aggregator
+        )
 
     def prepare_fairness_setup_query(
         self,
@@ -149,6 +156,40 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
             params={"f_type": self.f_type, "f_args": self.f_args},
         )
 
+    async def exchange_sensitive_groups_list_and_counts(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> List[int]:
+        """Agree on a list of sensitive groups and aggregate sample counts.
+
+        This method performs the following routine:
+
+        - Await `FairnessGroups` messages from clients with group definitions.
+        - Assign a sorted list of sensitive groups as `groups` attribute.
+        - Share that list with clients.
+        - Await possibly-encrypted group-wise sample counts from clients.
+        - (Secure-)Aggregate these sample counts and return them.
+
+        Parameters
+        ----------
+        netwk:
+            `NetworkServer` endpoint, through which a fairness setup query
+            was previously sent to all clients.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        counts:
+            List of group-wise total sample count across clients,
+            sorted based on the newly-assigned `self.groups`.
+        """
+        # Receive, aggregate, assign and send back sensitive group definitions.
+        self.groups = await self._exchange_sensitive_groups_list(netwk)
+        # Receive, (secure-)aggregate and return group-wise sample counts.
+        return await self._aggregate_sensitive_groups_counts(netwk, secagg)
+
     @staticmethod
     async def _exchange_sensitive_groups_list(
         netwk: NetworkServer,
@@ -213,6 +254,7 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkServer,
+        secagg: Optional[Decrypter],
         counts: List[int],
         aggregator: Aggregator,
     ) -> Aggregator:
@@ -238,13 +280,90 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
             or may not have been altered compared with the input one.
         """
 
+    # Fairness Round methods.
+
+    async def run_fairness_round(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> Dict[str, Union[float, np.ndarray]]:
+        """Secure-aggregate and post-process fairness measures.
+
+        This method is to be run **after** having sent a `FairnessQuery`
+        to clients. It consists in receiving, (secure-)aggregating and
+        post-processing measures that clients produce as a reply to that
+        query. This may involve further algorithm-specific communications.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint instance, to which clients are registered.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        metrics:
+            Fairness(-related) metrics computed as part of this routine,
+            as a dict mapping scalar or numpy array values with their name.
+        """
+        values = await self.receive_and_aggregate_fairness_measures(
+            netwk, secagg
+        )
+        return await self.finalize_fairness_round(netwk, secagg, values)
+
+    async def receive_and_aggregate_fairness_measures(
+        self,
+        netwk: NetworkServer,
+        secagg: Optional[Decrypter],
+    ) -> List[float]:
+        """Await and (secure-)aggregate client-wise fairness-related metrics.
+
+        This method is designed to be called after sending a `FairnessQuery`
+        to clients, and returns values that are yet to be parsed and used by
+        the algorithm-dependent `finalize_fairness_round` method.
+
+        Parameters
+        ----------
+        netwk:
+            NetworkServer endpoint instance, to which clients are registered.
+        secagg:
+            Optional SecAgg decryption controller.
+
+        Returns
+        -------
+        metrics:
+            List of sum-aggregated fairness-related metrics (as floats).
+            By default, these are group-wise accuracy values; this may
+            however be changed or expanded by algorithm-specific classes.
+        """
+        received = await netwk.wait_for_messages()
+        # Case when expecting cleartext values.
+        if secagg is None:
+            replies = await verify_client_messages_validity(
+                netwk, received, expected=FairnessReply
+            )
+            if len(set(len(r.values) for r in replies.values())) != 1:
+                error = "Clients sent fairness values of different lengths."
+                await netwk.broadcast_message(Error(error))
+                raise RuntimeError(error)
+            return [
+                sum(rval)
+                for rval in zip(*[reply.values for reply in replies.values()])
+            ]
+        # Case when expecting encrypted values.
+        secagg_replies = await verify_client_messages_validity(
+            netwk, received, expected=SecaggFairnessReply
+        )
+        agg_reply = aggregate_secagg_messages(secagg_replies, decrypter=secagg)
+        return agg_reply.values
+
     @abc.abstractmethod
     async def finalize_fairness_round(
         self,
-        round_i: int,
-        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
+        values: List[float],
     ) -> Dict[str, Union[float, np.ndarray]]:
         """Orchestrate a round of actions to enforce fairness.
 
@@ -254,21 +373,17 @@ class FairnessControllerServer(metaclass=abc.ABCMeta):
 
         Parameters
         ----------
-        round_i:
-            Index of the current round (reflecting that of an upcoming
-            training round).
-        values:
-            Aggregated metrics resulting from the fairness evaluation
-            run by clients at this round.
         netwk:
             NetworkServer endpoint instance, to which clients are registered.
         secagg:
             Optional SecAgg decryption controller.
+        values:
+            Aggregated metrics resulting from the fairness evaluation
+            run by clients at this round.
 
         Returns
         -------
         metrics:
-            Computed local fairness(-related) metrics computed as part
-            of this routine, as a dict mapping scalar or numpy array
-            values with their name.
+            Fairness(-related) metrics computed as part of this routine,
+            as a dict mapping scalar or numpy array values with their name.
         """
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index ace949f..264e639 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -84,7 +84,7 @@ class FairbatchControllerClient(FairnessControllerClient):
             If the sampling pobabilities' update fails.
         """
         # Receive aggregated sensitive weights.
-        received = await netwk.check_message()
+        received = await netwk.recv_message()
         message = await verify_server_message_validity(
             netwk, received, expected=FairbatchSamplingProbas
         )
@@ -114,15 +114,15 @@ class FairbatchControllerClient(FairnessControllerClient):
         thresh: Optional[float] = None,
     ) -> List[MeanMetric]:
         loss = self.computer.setup_loss_metric(model=self.manager.model)
-        metrics = super().setup_fairness_metrics()
+        metrics = super().setup_fairness_metrics(thresh=thresh)
         metrics.append(loss)
         return metrics
 
     async def finalize_fairness_round(
         self,
         netwk: NetworkClient,
-        values: Dict[str, Dict[Tuple[Any, ...], float]],
         secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Await updated loss weights from the server.
         await self._update_fairbatch_sampling_probas(netwk)
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
index ecf5971..89719fa 100644
--- a/declearn/fairness/fairbatch/_server.py
+++ b/declearn/fairness/fairbatch/_server.py
@@ -107,6 +107,7 @@ class FairbatchControllerServer(FairnessControllerServer):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkServer,
+        secagg: Optional[Decrypter],
         counts: List[int],
         aggregator: Aggregator,
     ) -> Aggregator:
@@ -150,10 +151,9 @@ class FairbatchControllerServer(FairnessControllerServer):
 
     async def finalize_fairness_round(
         self,
-        round_i: int,
-        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
+        values: List[float],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy and loss values.
         accuracy = dict(zip(self.groups, values[: len(self.groups)]))
diff --git a/declearn/fairness/fairfed/_client.py b/declearn/fairness/fairfed/_client.py
index 24efbef..c8a03b2 100644
--- a/declearn/fairness/fairfed/_client.py
+++ b/declearn/fairness/fairfed/_client.py
@@ -106,8 +106,8 @@ class FairfedControllerClient(FairnessControllerClient):
     async def finalize_fairness_round(
         self,
         netwk: NetworkClient,
-        values: Dict[str, Dict[Tuple[Any, ...], float]],
         secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Await absolute mean fairness across all clients.
         received = await netwk.recv_message()
diff --git a/declearn/fairness/fairfed/_server.py b/declearn/fairness/fairfed/_server.py
index 24b5348..99f274b 100644
--- a/declearn/fairness/fairfed/_server.py
+++ b/declearn/fairness/fairfed/_server.py
@@ -113,6 +113,7 @@ class FairfedControllerServer(FairnessControllerServer):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkServer,
+        secagg: Optional[Decrypter],
         counts: List[int],
         aggregator: Aggregator,
     ) -> Aggregator:
@@ -130,10 +131,9 @@ class FairfedControllerServer(FairnessControllerServer):
 
     async def finalize_fairness_round(
         self,
-        round_i: int,
-        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
+        values: List[float],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy values and compute fairness ones.
         accuracy = dict(zip(self.groups, values))
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index 4a75c06..aa127af 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -69,7 +69,7 @@ class FairgradControllerClient(FairnessControllerClient):
             If the weights' update fails.
         """
         # Receive aggregated sensitive weights.
-        received = await netwk.check_message()
+        received = await netwk.recv_message()
         message = await verify_server_message_validity(
             netwk, received, expected=FairgradWeights
         )
@@ -93,8 +93,8 @@ class FairgradControllerClient(FairnessControllerClient):
     async def finalize_fairness_round(
         self,
         netwk: NetworkClient,
-        values: Dict[str, Dict[Tuple[Any, ...], float]],
         secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Await updated loss weights from the server.
         await self._update_fairgrad_weights(netwk)
diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py
index 53ef1b7..822b2bc 100644
--- a/declearn/fairness/fairgrad/_server.py
+++ b/declearn/fairness/fairgrad/_server.py
@@ -189,6 +189,7 @@ class FairgradControllerServer(FairnessControllerServer):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkServer,
+        secagg: Optional[Decrypter],
         counts: List[int],
         aggregator: Aggregator,
     ) -> Aggregator:
@@ -230,10 +231,9 @@ class FairgradControllerServer(FairnessControllerServer):
 
     async def finalize_fairness_round(
         self,
-        round_i: int,
-        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
+        values: List[float],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy metrics and update loss weights.
         accuracy = dict(zip(self.groups, values))
diff --git a/declearn/fairness/monitor/_client.py b/declearn/fairness/monitor/_client.py
index 5b41158..01aa352 100644
--- a/declearn/fairness/monitor/_client.py
+++ b/declearn/fairness/monitor/_client.py
@@ -45,8 +45,8 @@ class FairnessMonitorClient(FairnessControllerClient):
     async def finalize_fairness_round(
         self,
         netwk: NetworkClient,
-        values: Dict[str, Dict[Tuple[Any, ...], float]],
         secagg: Optional[Encrypter],
+        values: Dict[str, Dict[Tuple[Any, ...], float]],
     ) -> Dict[str, Union[float, np.ndarray]]:
         return {
             f"{metric}_{group}": value
diff --git a/declearn/fairness/monitor/_server.py b/declearn/fairness/monitor/_server.py
index d319a9c..12f2c20 100644
--- a/declearn/fairness/monitor/_server.py
+++ b/declearn/fairness/monitor/_server.py
@@ -42,7 +42,7 @@ class FairnessMonitorServer(FairnessControllerServer):
     def __init__(
         self,
         f_type: str,
-        f_args: Optional[Dict[str, Any]],
+        f_args: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(f_type, f_args)
         # Assign a temporary fairness functions, replaced at setup time.
@@ -53,6 +53,7 @@ class FairnessMonitorServer(FairnessControllerServer):
     async def finalize_fairness_setup(
         self,
         netwk: NetworkServer,
+        secagg: Optional[Decrypter],
         counts: List[int],
         aggregator: Aggregator,
     ) -> Aggregator:
@@ -65,10 +66,9 @@ class FairnessMonitorServer(FairnessControllerServer):
 
     async def finalize_fairness_round(
         self,
-        round_i: int,
-        values: List[float],
         netwk: NetworkServer,
         secagg: Optional[Decrypter],
+        values: List[float],
     ) -> Dict[str, Union[float, np.ndarray]]:
         # Unpack group-wise accuracy metrics and compute fairness ones.
         accuracy = dict(zip(self.groups, values))
diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index c1619e7..1b3f45e 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -660,7 +660,7 @@ class FederatedClient:
             await self.netwk.send_message(messaging.Error(error))
             return
         # Otherwise, run the controller's routine.
-        metrics = await self.fairness.fairness_round(
+        metrics = await self.fairness.run_fairness_round(
             netwk=self.netwk, query=query, secagg=self._encrypter
         )
         # Optionally save computed fairness metrics.
diff --git a/declearn/main/_server.py b/declearn/main/_server.py
index 11f157a..d0eecfc 100644
--- a/declearn/main/_server.py
+++ b/declearn/main/_server.py
@@ -565,27 +565,8 @@ class FederatedServer:
             weights=None,
         )
         await self._send_request_with_optional_weights(query, clients)
-        # Await and (secure-)aggregate) results.
-        self.logger.info("Awaiting clients' fairness measures.")
-        if self._decrypter is None:
-            replies = await self._collect_results(
-                clients, messaging.FairnessReply, "fairness round"
-            )
-            if len(set(len(r.values) for r in replies.values())) != 1:
-                error = "Clients sent fairness values of different lengths."
-                self.logger.error(error)
-                await self.netwk.broadcast_message(messaging.Error(error))
-                raise RuntimeError(error)
-            values = [sum(c_values) for c_values in zip(*replies.values())]
-        else:
-            secagg_replies = await self._collect_results(
-                clients, secagg_messaging.SecaggFairnessReply, "fairness round"
-            )
-            values = self._aggregate_secagg_replies(secagg_replies).values
-        # Have the fairness controller process results.
-        metrics = await self.fairness.finalize_fairness_round(
-            round_i=round_i,
-            values=values,
+        # Await, (secure-)aggregate and process fairness measures.
+        metrics = await self.fairness.run_fairness_round(
             netwk=self.netwk,
             secagg=self._decrypter,
         )
diff --git a/test/main/test_main_client.py b/test/main/test_main_client.py
index 391daec..7a591c7 100644
--- a/test/main/test_main_client.py
+++ b/test/main/test_main_client.py
@@ -1067,13 +1067,13 @@ class TestFederatedClientFairnessRound:
         # Call the 'fairness_round' routine and verify expected actions.
         request = messaging.FairnessQuery(round_i=1)
         await client.fairness_round(request)
-        fairness.fairness_round.assert_awaited_once_with(
+        fairness.run_fairness_round.assert_awaited_once_with(
             netwk=netwk, query=request, secagg=None
         )
         # Verify that when a checkpointer is set, it is used.
         if ckpt:
             client.ckptr.save_metrics.assert_called_once_with(  # type: ignore
-                metrics=fairness.fairness_round.return_value,
+                metrics=fairness.run_fairness_round.return_value,
                 prefix="fairness_metrics",
                 append=True,
                 timestamp="round_1",
@@ -1102,7 +1102,7 @@ class TestFederatedClientFairnessRound:
         # Call the 'fairness_round' routine and verify expected actions.
         request = messaging.FairnessQuery(round_i=1)
         await client.fairness_round(request)
-        fairness.fairness_round.assert_awaited_once_with(
+        fairness.run_fairness_round.assert_awaited_once_with(
             netwk=netwk,
             query=request,
             secagg=secagg.setup_encrypter.return_value,
@@ -1149,7 +1149,7 @@ class TestFederatedClientFairnessRound:
         netwk.send_message.assert_called_once()
         reply = netwk.send_message.call_args.args[0]
         assert isinstance(reply, messaging.Error)
-        fairness.fairness_round.assert_not_called()
+        fairness.run_fairness_round.assert_not_called()
 
 
 class TestFederatedClientMisc:
-- 
GitLab