diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py index eb491f872086dc4ab97d01fd0dae77a909053537..dc283d7d4669c8178c69440c7fdf5117db5d99eb 100644 --- a/declearn/fairness/__init__.py +++ b/declearn/fairness/__init__.py @@ -72,6 +72,9 @@ Algorithms submodules * [monitor][declearn.fairness.monitor]: Fairness-monitoring controllers, that leave training unaltered. +Note that the controllers implemented under these submodules +are type-registered under the submodule's name. + References ---------- diff --git a/declearn/fairness/fairbatch/_client.py b/declearn/fairness/fairbatch/_client.py index dd91840b5f2bca28f16fa4346aa6e4fa2795d569..9f7891615bda4cc5dac14fea3d078e864a807f46 100644 --- a/declearn/fairness/fairbatch/_client.py +++ b/declearn/fairness/fairbatch/_client.py @@ -46,7 +46,7 @@ __all__ = [ class FairbatchControllerClient(FairnessControllerClient): """Client-side controller to implement Fed-FairBatch or FedFB.""" - algorithm = "fedfairbatch" + algorithm = "fairbatch" def __init__( self, diff --git a/declearn/fairness/fairbatch/_server.py b/declearn/fairness/fairbatch/_server.py index 603d314ccfe742b6c84df0e1d25f85d714d03778..80f88aa9a186ac6ab004a7a744d8100a5b690761 100644 --- a/declearn/fairness/fairbatch/_server.py +++ b/declearn/fairness/fairbatch/_server.py @@ -73,7 +73,7 @@ class FairbatchControllerServer(FairnessControllerServer): https://arxiv.org/abs/2110.15545 """ - algorithm = "fedfairbatch" + algorithm = "fairbatch" def __init__( self, diff --git a/declearn/fairness/fairgrad/_client.py b/declearn/fairness/fairgrad/_client.py index aa127af28f27d72d8e844791f48ad9a879ecb806..9cfd150721c023843b43cfb69148173389863fc2 100644 --- a/declearn/fairness/fairgrad/_client.py +++ b/declearn/fairness/fairgrad/_client.py @@ -40,7 +40,7 @@ __all__ = [ class FairgradControllerClient(FairnessControllerClient): """Client-side controller to implement Fed-FairGrad.""" - algorithm = "fedfairgrad" + algorithm = "fairgrad" async def finalize_fairness_setup( self, diff --git a/declearn/fairness/fairgrad/_server.py b/declearn/fairness/fairgrad/_server.py index 8364e28ba72e9812c723e868cca634cab9c2a351..b4e4efcdc8a7a65df46847dea170cc3c8b5d18d3 100644 --- a/declearn/fairness/fairgrad/_server.py +++ b/declearn/fairness/fairgrad/_server.py @@ -176,7 +176,7 @@ class FairgradControllerServer(FairnessControllerServer): https://openreview.net/forum?id=0f8tU3QwWD """ - algorithm = "fedfairgrad" + algorithm = "fairgrad" def __init__( self, diff --git a/test/main/test_config_optim.py b/test/main/test_config_optim.py index 8c030a4b3864a7b54b0da5ead80fcabe9e2e5e5f..0fb4128af361f0bfd2dc24a209e0e3c170561fd2 100644 --- a/test/main/test_config_optim.py +++ b/test/main/test_config_optim.py @@ -190,7 +190,7 @@ class TestFLOptimConfig: """Test parsing 'fairness' from a dict.""" field = FIELDS["fairness"] config = { - "algorithm": "fedfairgrad", + "algorithm": "fairgrad", "f_type": "demographic_parity", "eta": 0.1, "eps": 0.0, @@ -204,7 +204,7 @@ class TestFLOptimConfig: def test_parse_fairness_dict_error(self) -> None: """Test parsing 'fairness' from an invalid dict.""" field = FIELDS["fairness"] - config = {"algorithm": "fedfairgrad"} # missing f_type choice + config = {"algorithm": "fairgrad"} # missing f_type choice with pytest.raises(TypeError): FLOptimConfig.parse_fairness(field, config) @@ -230,7 +230,7 @@ class TestFLOptimConfig: lrate = 1.0 modules = [["adam", {beta_1=0.8, beta_2=0.9}]] [optim.fairness] - algorithm = "fedfairgrad" + algorithm = "fairgrad" f_type = "equalized_odds" eta = 0.1 eps = 0.0