From b39b232650b1f128d696c4839e7abc94328d4059 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 23 Jul 2024 11:30:46 +0200
Subject: [PATCH] Change fairness algorithms registration names to match
 submodule names.

---
 declearn/fairness/__init__.py          | 3 +++
 declearn/fairness/fairbatch/_client.py | 2 +-
 declearn/fairness/fairbatch/_server.py | 2 +-
 declearn/fairness/fairgrad/_client.py  | 2 +-
 declearn/fairness/fairgrad/_server.py  | 2 +-
 test/main/test_config_optim.py         | 6 +++---
 6 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py
index eb491f8..dc283d7 100644
--- a/declearn/fairness/__init__.py
+++ b/declearn/fairness/__init__.py
@@ -72,6 +72,9 @@ Algorithms submodules
 * [monitor][declearn.fairness.monitor]:
     Fairness-monitoring controllers, that leave training unaltered.
 
+Note that the controllers implemented under these submodules
+are type-registered under the submodule's name.
+
 References
 ----------
 
diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py
index dd91840..9f78916 100644
--- a/declearn/fairness/fairbatch/_client.py
+++ b/declearn/fairness/fairbatch/_client.py
@@ -46,7 +46,7 @@ __all__ = [
 class FairbatchControllerClient(FairnessControllerClient):
     """Client-side controller to implement Fed-FairBatch or FedFB."""
 
-    algorithm = "fedfairbatch"
+    algorithm = "fairbatch"
 
     def __init__(
         self,
diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py
index 603d314..80f88aa 100644
--- a/declearn/fairness/fairbatch/_server.py
+++ b/declearn/fairness/fairbatch/_server.py
@@ -73,7 +73,7 @@ class FairbatchControllerServer(FairnessControllerServer):
         https://arxiv.org/abs/2110.15545
     """
 
-    algorithm = "fedfairbatch"
+    algorithm = "fairbatch"
 
     def __init__(
         self,
diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py
index aa127af..9cfd150 100644
--- a/declearn/fairness/fairgrad/_client.py
+++ b/declearn/fairness/fairgrad/_client.py
@@ -40,7 +40,7 @@ __all__ = [
 class FairgradControllerClient(FairnessControllerClient):
     """Client-side controller to implement Fed-FairGrad."""
 
-    algorithm = "fedfairgrad"
+    algorithm = "fairgrad"
 
     async def finalize_fairness_setup(
         self,
diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py
index 8364e28..b4e4efc 100644
--- a/declearn/fairness/fairgrad/_server.py
+++ b/declearn/fairness/fairgrad/_server.py
@@ -176,7 +176,7 @@ class FairgradControllerServer(FairnessControllerServer):
         https://openreview.net/forum?id=0f8tU3QwWD
     """
 
-    algorithm = "fedfairgrad"
+    algorithm = "fairgrad"
 
     def __init__(
         self,
diff --git a/test/main/test_config_optim.py b/test/main/test_config_optim.py
index 8c030a4..0fb4128 100644
--- a/test/main/test_config_optim.py
+++ b/test/main/test_config_optim.py
@@ -190,7 +190,7 @@ class TestFLOptimConfig:
         """Test parsing 'fairness' from a dict."""
         field = FIELDS["fairness"]
         config = {
-            "algorithm": "fedfairgrad",
+            "algorithm": "fairgrad",
             "f_type": "demographic_parity",
             "eta": 0.1,
             "eps": 0.0,
@@ -204,7 +204,7 @@ class TestFLOptimConfig:
     def test_parse_fairness_dict_error(self) -> None:
         """Test parsing 'fairness' from an invalid dict."""
         field = FIELDS["fairness"]
-        config = {"algorithm": "fedfairgrad"}  # missing f_type choice
+        config = {"algorithm": "fairgrad"}  # missing f_type choice
         with pytest.raises(TypeError):
             FLOptimConfig.parse_fairness(field, config)
 
@@ -230,7 +230,7 @@ class TestFLOptimConfig:
             lrate = 1.0
             modules = [["adam", {beta_1=0.8, beta_2=0.9}]]
         [optim.fairness]
-            algorithm = "fedfairgrad"
+            algorithm = "fairgrad"
             f_type = "equalized_odds"
             eta = 0.1
             eps = 0.0
-- 
GitLab