From b39b232650b1f128d696c4839e7abc94328d4059 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Tue, 23 Jul 2024 11:30:46 +0200 Subject: [PATCH] Change fairness algorithms registration names to match submodule names. --- declearn/fairness/__init__.py | 3 +++ declearn/fairness/fairbatch/_client.py | 2 +- declearn/fairness/fairbatch/_server.py | 2 +- declearn/fairness/fairgrad/_client.py | 2 +- declearn/fairness/fairgrad/_server.py | 2 +- test/main/test_config_optim.py | 6 +++--- 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py index eb491f8..dc283d7 100644 --- a/declearn/fairness/__init__.py +++ b/declearn/fairness/__init__.py @@ -72,6 +72,9 @@ Algorithms submodules * [monitor][declearn.fairness.monitor]: Fairness-monitoring controllers, that leave training unaltered. +Note that the controllers implemented under these submodules +are type-registered under the submodule's name. + References ---------- diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index dd91840..9f78916 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -46,7 +46,7 @@ __all__ = [ class FairbatchControllerClient(FairnessControllerClient): """Client-side controller to implement Fed-FairBatch or FedFB.""" - algorithm = "fedfairbatch" + algorithm = "fairbatch" def __init__( self, diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py index 603d314..80f88aa 100644 --- a/declearn/fairness/fairbatch/_server.py +++ b/declearn/fairness/fairbatch/_server.py @@ -73,7 +73,7 @@ class FairbatchControllerServer(FairnessControllerServer): https://arxiv.org/abs/2110.15545 """ - algorithm = "fedfairbatch" + algorithm = "fairbatch" def __init__( self, diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index aa127af..9cfd150 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -40,7 +40,7 @@ __all__ = [ class FairgradControllerClient(FairnessControllerClient): """Client-side controller to implement Fed-FairGrad.""" - algorithm = "fedfairgrad" + algorithm = "fairgrad" async def finalize_fairness_setup( self, diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py index 8364e28..b4e4efc 100644 --- a/declearn/fairness/fairgrad/_server.py +++ b/declearn/fairness/fairgrad/_server.py @@ -176,7 +176,7 @@ class FairgradControllerServer(FairnessControllerServer): https://openreview.net/forum?id=0f8tU3QwWD """ - algorithm = "fedfairgrad" + algorithm = "fairgrad" def __init__( self, diff --git a/test/main/test_config_optim.py b/test/main/test_config_optim.py index 8c030a4..0fb4128 100644 --- a/test/main/test_config_optim.py +++ b/test/main/test_config_optim.py @@ -190,7 +190,7 @@ class TestFLOptimConfig: """Test parsing 'fairness' from a dict.""" field = FIELDS["fairness"] config = { - "algorithm": "fedfairgrad", + "algorithm": "fairgrad", "f_type": "demographic_parity", "eta": 0.1, "eps": 0.0, @@ -204,7 +204,7 @@ class TestFLOptimConfig: def test_parse_fairness_dict_error(self) -> None: """Test parsing 'fairness' from an invalid dict.""" field = FIELDS["fairness"] - config = {"algorithm": "fedfairgrad"} # missing f_type choice + config = {"algorithm": "fairgrad"} # missing f_type choice with pytest.raises(TypeError): FLOptimConfig.parse_fairness(field, config) @@ -230,7 +230,7 @@ class TestFLOptimConfig: lrate = 1.0 modules = [["adam", {beta_1=0.8, beta_2=0.9}]] [optim.fairness] - algorithm = "fedfairgrad" + algorithm = "fairgrad" f_type = "equalized_odds" eta = 0.1 eps = 0.0 -- GitLab