From 3c1bfc1223c2b7b3ac6ea9ab0b57f7a2fd63d53f Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 3 Jun 2024 15:45:01 +0200
Subject: [PATCH] Update Fairness API to limit redundant computations.

- Have clients' measures-computing routine return both shareable
  and local-use values.
- This way, unscaled metrics that are to be post-processed and/or
  checkpointed can be kept as-is rather than scaled-then-descaled.
---
 declearn/fairness/api/_client.py       | 29 +++++++++++--------
 declearn/fairness/core/_accuracy.py    | 18 ++++++++++++
 declearn/fairness/core/_fair_func.py   |  1 +
 declearn/fairness/fairbatch/_client.py | 39 +++++++++++---------------
 declearn/fairness/fairgrad/_client.py  | 25 ++++++++---------
 5 files changed, 64 insertions(+), 48 deletions(-)

diff --git a/declearn/fairness/api/_client.py b/declearn/fairness/api/_client.py
index 7dd98ac..7b60c90 100644
--- a/declearn/fairness/api/_client.py
+++ b/declearn/fairness/api/_client.py
@@ -254,19 +254,20 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         # Optionally update the wrapped model's weights.
         if query.weights is not None:
             self.manager.model.set_weights(query.weights, trainable=True)
-        # Compute, opt. encrypt and share fairness-related metrics.
-        values = self.compute_fairness_measures(
+        # Compute some fairness-related values, split between two sets.
+        share_values, local_values = self.compute_fairness_measures(
             query.batch_size, query.n_batch, query.thresh
         )
-        reply = FairnessReply(values=values)
+        # Share the first set of values for their (secure-)aggregation.
+        reply = FairnessReply(values=share_values)
         if secagg is None:
             await netwk.send_message(reply)
         else:
             await netwk.send_message(
                 SecaggFairnessReply.from_cleartext_message(reply, secagg)
             )
-        # Return computed values.
-        return values
+        # Return the second set of values.
+        return local_values
 
     @abc.abstractmethod
     def compute_fairness_measures(
@@ -274,7 +275,7 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         batch_size: int,
         n_batch: Optional[int] = None,
         thresh: Optional[float] = None,
-    ) -> List[float]:
+    ) -> Tuple[List[float], List[float]]:
         """Compute fairness measures based on a received query.
 
         By default, compute and return group-wise accuracy metrics,
@@ -297,9 +298,13 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
 
         Returns
         -------
-        values:
-            Computed values, as a deterministic-length ordered list
-            of float values.
+        share_values:
+            Values that are to be shared with the orchestrating server,
+            as a deterministic-length list of float values.
+        local_values:
+            Values that are to be used in local post-processing steps.
+            This may be a reference to `share_values`, but is typically
+            designed to contain unscaled measures to checkpoint.
         """
 
     @abc.abstractmethod
@@ -320,8 +325,10 @@ class FairnessControllerClient(metaclass=abc.ABCMeta):
         netwk:
             NetworkClient endpoint instance, connected to a server.
         values:
-            List of locally-computed evaluation metrics, already shared
-            with the server for their (secure-)aggregation.
+            List of locally-computed evaluation metrics.
+            This is the second set of `compute_fairness_measures` return
+            values; when this method is called, the first has already
+            been shared with the server for (secure-)aggregation.
         secagg:
             Optional SecAgg encryption controller.
 
diff --git a/declearn/fairness/core/_accuracy.py b/declearn/fairness/core/_accuracy.py
index 0f2c32c..e842ef3 100644
--- a/declearn/fairness/core/_accuracy.py
+++ b/declearn/fairness/core/_accuracy.py
@@ -310,3 +310,21 @@ class FairnessAccuracyComputer:
             g_losses[group] = float(results[ModelLoss.name])
         # Return the pair of dicts storing results.
         return accuracy, g_losses
+
+    def scale_metrics_by_sample_counts(
+        self,
+        metrics: Dict[Tuple[Any, ...], float],
+    ) -> Dict[Tuple[Any, ...], float]:
+        """Scale a dict of computed group-wise metrics by sample counts.
+
+        Parameters
+        ----------
+        metrics:
+            Pre-computed raw metrics, as a `{group_k: score_k}` dict.
+
+        Returns
+        -------
+        metrics:
+            Scaled matrics, as a `{group_k: n_k * score_k}` dict.
+        """
+        return {key: val * self.counts[key] for key, val in metrics.items()}
diff --git a/declearn/fairness/core/_fair_func.py b/declearn/fairness/core/_fair_func.py
index 6fab534..0a042fc 100644
--- a/declearn/fairness/core/_fair_func.py
+++ b/declearn/fairness/core/_fair_func.py
@@ -220,6 +220,7 @@ class FairnessFunction(metaclass=abc.ABCMeta):
             Values' interpretation depend on the implemented group-fairness
             definition, but overall the fairer the accuracy towards a group,
             the closer the metric is to zero.
+
         Raises
         ------
         KeyError
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index 41f4460..378aafd 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -17,7 +17,7 @@
 
 """Client-side Fed-FairBatch controller."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -126,7 +126,7 @@ class FairbatchControllerClient(FairnessControllerClient):
         batch_size: int,
         n_batch: Optional[int] = None,
         thresh: Optional[float] = None,
-    ) -> List[float]:
+    ) -> Tuple[List[float], List[float]]:
         # Compute group-wise accuracy scores and loss values.
         accuracy, loss = self.computer.compute_groupwise_accuracy_and_loss(
             model=self.manager.model,
@@ -134,18 +134,18 @@ class FairbatchControllerClient(FairnessControllerClient):
             n_batch=n_batch,
             thresh=thresh,
         )
-        # Multiply these values by sample counts.
-        accuracy = {
-            key: val * self.computer.counts[key]
-            for key, val in accuracy.items()
-        }
-        loss = {
-            key: val * self.computer.counts[key] for key, val in loss.items()
-        }
-        # Return shareable group-wise values, ordered and filled out.
-        return [accuracy.get(group, 0.0) for group in self.groups] + [
-            loss.get(group, 0.0) for group in self.groups
+        # Flatten local values for post-processing and checkpointing.
+        local_values = list(accuracy.values()) + list(loss.values())
+        # Scale local values by sample counts for their aggregation.
+        accuracy = self.computer.scale_metrics_by_sample_counts(accuracy)
+        loss = self.computer.scale_metrics_by_sample_counts(loss)
+        # Flatten shareable values, ordered and filled-out.
+        share_values = [
+            *[accuracy.get(group, 0.0) for group in self.groups],
+            *[loss.get(group, 0.0) for group in self.groups],
         ]
+        # Return both sets of values.
+        return share_values, local_values
 
     async def finalize_fairness_round(
         self,
@@ -156,16 +156,9 @@ class FairbatchControllerClient(FairnessControllerClient):
         # Await updated loss weights from the server.
         await self._update_fairbatch_sampling_probas(netwk)
         # Recover raw accuracy and loss values for groups with local samples.
-        accuracy = {
-            key: val / self.computer.counts[key]
-            for key, val in zip(self.groups, values[: len(self.groups)])
-            if key in self.computer.counts
-        }
-        loss = {
-            key: val / self.computer.counts[key]
-            for key, val in zip(self.groups, values[len(self.groups) :])
-            if key in self.computer.counts
-        }
+        groups = list(self.computer.g_data)
+        accuracy = dict(zip(groups, values[: len(groups)]))
+        loss = dict(zip(groups, values[len(groups) :]))
         # Compute local fairness measures.
         fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
         f_type = self.fairness_function.f_type
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index 3e21b27..11308d0 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -17,7 +17,7 @@
 
 """Client-side Fed-FairGrad controller."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -116,7 +116,7 @@ class FairgradControllerClient(FairnessControllerClient):
         batch_size: int,
         n_batch: Optional[int] = None,
         thresh: Optional[float] = None,
-    ) -> List[float]:
+    ) -> Tuple[List[float], List[float]]:
         # Compute group-wise accuracy scores.
         accuracy = self.computer.compute_groupwise_accuracy(
             model=self.manager.model,
@@ -124,13 +124,14 @@ class FairgradControllerClient(FairnessControllerClient):
             n_batch=n_batch,
             thresh=thresh,
         )
-        # Multiply these scores by sample counts.
-        accuracy = {
-            key: val * self.computer.counts[key]
-            for key, val in accuracy.items()
-        }
-        # Return shareable group-wise values, ordered and filled out.
-        return [accuracy.get(group, 0.0) for group in self.groups]
+        # Flatten local values for post-processing and checkpointing.
+        local_values = list(accuracy.values())
+        # Scale local values by sample counts for their aggregation.
+        accuracy = self.computer.scale_metrics_by_sample_counts(accuracy)
+        # Flatten shareable values, ordered and filled-out.
+        share_values = [accuracy.get(group, 0.0) for group in self.groups]
+        # Return both sets of values.
+        return share_values, local_values
 
     async def finalize_fairness_round(
         self,
@@ -141,11 +142,7 @@ class FairgradControllerClient(FairnessControllerClient):
         # Await updated loss weights from the server.
         await self._update_fairgrad_weights(netwk)
         # Recover raw accuracy scores for groups with local samples.
-        accuracy = {
-            key: val / self.computer.counts[key]
-            for key, val in zip(self.groups, values)
-            if key in self.computer.counts
-        }
+        accuracy = dict(zip(self.computer.g_data, values))
         # Compute local fairness measures.
         fairness = self.fairness_function.compute_from_group_accuracy(accuracy)
         f_type = self.fairness_function.f_type
-- 
GitLab