diff --git a/declearn/fairness/__init__.py b/declearn/fairness/__init__.py index 929b2e82f9a92e87f7ce0b0dd46d2d64f68a91c9..eb491f872086dc4ab97d01fd0dae77a909053537 100644 --- a/declearn/fairness/__init__.py +++ b/declearn/fairness/__init__.py @@ -88,7 +88,7 @@ References Improving Fairness via Federated Learning. https://arxiv.org/abs/2110.15545 - [4] - Eszzeldin et al. (2021). + Ezzeldin et al. (2021). FairFed: Enabling Group Fairness in Federated Learning https://arxiv.org/abs/2110.00857 """ diff --git a/declearn/fairness/core/__init__.py b/declearn/fairness/core/__init__.py index 3009a7c5e1ed4fc2f94c93d3bd025e4a5694eb8c..7ee5702c240d64820970c234438fd77e850ed676 100644 --- a/declearn/fairness/core/__init__.py +++ b/declearn/fairness/core/__init__.py @@ -40,6 +40,11 @@ Concrete implementations of various fairness functions: Equalized Odds group-fairness function. Abstraction and generic constructor may be found in [declearn.fairness.api][]. +An additional util may be used to list available functions, either declared +here or by third-party and end-user code: + +* [list_fairness_functions][declearn.fairness.core.list_fairness_functions]: + Return a mapping of registered FairnessFunction subclasses. """ from ._functions import ( @@ -47,5 +52,6 @@ from ._functions import ( DemographicParityFunction, EqualityOfOpportunityFunction, EqualizedOddsFunction, + list_fairness_functions, ) from ._inmemory import FairnessInMemoryDataset diff --git a/declearn/fairness/core/_functions.py b/declearn/fairness/core/_functions.py index 883f4139626771229454de650b6a4666e888ff3b..6d4dbc735892c311a98a6c85a2088bd97254618f 100644 --- a/declearn/fairness/core/_functions.py +++ b/declearn/fairness/core/_functions.py @@ -17,20 +17,47 @@ """Concrete implementations of various group-fairness functions.""" -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Type, Union import numpy as np from declearn.fairness.api import FairnessFunction +from declearn.utils import access_types_mapping __all__ = ( "AccuracyParityFunction", "DemographicParityFunction", "EqualityOfOpportunityFunction", "EqualizedOddsFunction", + "list_fairness_functions", ) +def list_fairness_functions() -> Dict[str, Type[FairnessFunction]]: + """Return a mapping of registered FairnessFunction subclasses. + + This function aims at making it easy for end-users to list and access + all available FairnessFunction classes at any given time. The returned + dict uses unique identifier keys, which may be used to use the associated + function within a [declearn.fairness.api.FairnessControllerServer][]. + + Note that the mapping will include all declearn-provided functions, + but also registered functions provided by user or third-party code. + + See also + -------- + * [declearn.fairness.api.FairnessFunction][]: + API-defining abstract base class for the FairnessFunction classes. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to `FairnessFunction` + class constructors. + """ + return access_types_mapping("FairnessFunction") + + class AccuracyParityFunction(FairnessFunction): """Accuracy Parity group-fairness function. @@ -120,7 +147,7 @@ class DemographicParityFunction(FairnessFunction): evaluated classifier. In other words, Demographic Parity is achieved when the probability to - predict any given label is indenpendent from the sensitive attribute(s) + predict any given label is independent from the sensitive attribute(s) (regardless of whether that label is accurate or not). Formula diff --git a/test/fairness/test_fairness_functions.py b/test/fairness/test_fairness_functions.py index a0d3ee0dc6f68472f61d71ad0f9847b82d89da77..a40746d33b63559886aa7867bec20ef3b95f1cf9 100644 --- a/test/fairness/test_fairness_functions.py +++ b/test/fairness/test_fairness_functions.py @@ -32,6 +32,7 @@ from declearn.fairness.core import ( DemographicParityFunction, EqualityOfOpportunityFunction, EqualizedOddsFunction, + list_fairness_functions, ) from declearn.test_utils import assert_dict_equal @@ -297,3 +298,20 @@ class TestEqualityOfOpportunity(TestEqualizedOddsFunction): counts=self.counts, target="wrong-type", # type: ignore ) + + +def test_list_fairness_functions() -> None: + """Test 'declearn.fairness.core.list_fairness_functions'.""" + mapping = list_fairness_functions() + assert isinstance(mapping, dict) + assert all( + isinstance(key, str) and issubclass(val, FairnessFunction) + for key, val in mapping.items() + ) + for cls in ( + AccuracyParityFunction, + DemographicParityFunction, + EqualityOfOpportunityFunction, + EqualizedOddsFunction, + ): + assert mapping.get(cls.f_type) is cls # type: ignore