diff --git a/declearn/aggregator/__init__.py b/declearn/aggregator/__init__.py
index 76c285654fb199870664407309608b7e8a3e9574..1ac99d21dd7d75d4198cdb10de71cd61cce79f6e 100644
--- a/declearn/aggregator/__init__.py
+++ b/declearn/aggregator/__init__.py
@@ -17,22 +17,32 @@
 
 """Model updates aggregating API and implementations.
 
-An Aggregator is typically meant to be used on a round-wise basis by
+An `Aggregator` is typically meant to be used on a round-wise basis by
 the orchestrating server of a centralized federated learning process,
-to aggregate the client-wise model updated into a Vector that may then
-be used as "gradients" by the server's Optimizer to update the global
+to aggregate the client-wise model updates into a `Vector` that may then
+be used as "gradients" by the server's `Optimizer` to update the global
 model.
 
 This declearn submodule provides with:
 
+API tools
+---------
+
 * [Aggregator][declearn.aggregator.Aggregator]:
-    abstract class defining an API for Vector aggregation
+    Abstract base class defining an API for Vector aggregation.
+* [list_aggregators][declearn.aggregator.list_aggregators]:
+    Return a mapping of registered Aggregator subclasses.
+
+
+Concrete classes
+----------------
+
 * [AveragingAggregator][declearn.aggregator.AveragingAggregator]:
-    average-based-aggregation Aggregator subclass
+    Average-based aggregation Aggregator subclass.
 * [GradientMaskedAveraging][declearn.aggregator.GradientMaskedAveraging]:
-    gradient Masked Averaging Aggregator subclass
+    Gradient Masked Averaging Aggregator subclass.
 """
 
-from ._api import Aggregator
+from ._api import Aggregator, list_aggregators
 from ._base import AveragingAggregator
 from ._gma import GradientMaskedAveraging
diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py
index 3711dda11e69d6d8771970ce1bc9f344cbb98cc4..778b1d3e0c92a7e0c0baedea85e5b35ae6ace0a7 100644
--- a/declearn/aggregator/_api.py
+++ b/declearn/aggregator/_api.py
@@ -18,18 +18,25 @@
 """Model updates aggregation API."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict
+from typing import Any, ClassVar, Dict, Type, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Vector
-from declearn.utils import create_types_registry, register_type
+from declearn.utils import (
+    access_types_mapping,
+    create_types_registry,
+    register_type,
+)
 
 __all__ = [
     "Aggregator",
 ]
 
+T = TypeVar("T")
+
+
 @create_types_registry
 class Aggregator(metaclass=ABCMeta):
     """Abstract class defining an API for Vector aggregation.
@@ -90,9 +97,9 @@ class Aggregator(metaclass=ABCMeta):
     @abstractmethod
     def aggregate(
         self,
-        updates: Dict[str, Vector],
+        updates: Dict[str, Vector[T]],
         n_steps: Dict[str, int],  # revise: abstract~generalize kwargs use
-    ) -> Vector:
+    ) -> Vector[T]:
         """Aggregate input vectors into a single one.
 
         Parameters
----------
@@ -109,6 +116,11 @@
         gradients: Vector
             Aggregated updates, as a Vector - treated as gradients by
             the server-side optimizer.
+
+        Raises
+        ------
+        TypeError
+            If the input `updates` is an empty dict.
         """
 
     def get_config(
@@ -124,3 +136,29 @@
     ) -> Self:
         """Instantiate an Aggregator from its configuration dict."""
         return cls(**config)
+
+
+def list_aggregators() -> Dict[str, Type[Aggregator]]:
+    """Return a mapping of registered Aggregator subclasses.
+ + This function aims at making it easy for end-users to list and access + all available Aggregator classes at any given time. The returned dict + uses unique identifier keys, which may be used to specify the desired + algorithm as part of a federated learning process without going through + the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.aggregator.Aggregator][]: + API-defining abstract base class for the aggregation algorithms. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to `Aggregator` class + constructors. + """ + return access_types_mapping("Aggregator") diff --git a/declearn/aggregator/_base.py b/declearn/aggregator/_base.py index 460f312358be6e7076e9114a5da779770c567415..1adebe0537421e83c4854782398df8c43b044e20 100644 --- a/declearn/aggregator/_base.py +++ b/declearn/aggregator/_base.py @@ -19,7 +19,6 @@ from typing import Any, ClassVar, Dict, Optional -from typing_extensions import Self # future: import from typing (py >=3.11) from declearn.aggregator._api import Aggregator from declearn.model.api import Vector @@ -76,13 +75,6 @@ class AveragingAggregator(Aggregator): "client_weights": self.client_weights, } - @classmethod - def from_config( - cls, - config: Dict[str, Any], - ) -> Self: - return cls(**config) - def aggregate( self, updates: Dict[str, Vector], diff --git a/declearn/main/config/_strategy.py b/declearn/main/config/_strategy.py index 2fcd058e6eb8276ff16f26d786790ffc44a309de..333072a74f9d122aa4d76df2fae8347c4ebad5cd 100644 --- a/declearn/main/config/_strategy.py +++ b/declearn/main/config/_strategy.py @@ -22,7 +22,7 @@ import functools from typing import Any, Dict, Union -from declearn.aggregator import Aggregator +from declearn.aggregator import Aggregator, AveragingAggregator from declearn.optimizer import Optimizer from declearn.utils import TomlConfig, access_registered, deserialize_object @@ -95,7 +95,9 @@ class FLOptimConfig(TomlConfig): server_opt: Optimizer = dataclasses.field( default_factory=functools.partial(Optimizer, lrate=1.0) ) - aggregator: Aggregator = dataclasses.field(default_factory=Aggregator) + aggregator: Aggregator = dataclasses.field( + default_factory=AveragingAggregator + ) @classmethod def parse_client_opt( @@ -147,7 +149,7 @@ class FLOptimConfig(TomlConfig): - (opt.) group: str used to retrieve the registered class - (opt.) config: dict specifying kwargs for the constructor - any other field will be added to the `config` kwargs dict - - as None (or missing kwarg), using default AverageAggregator() + - as None (or missing kwarg), using default AveragingAggregator() """ # Case when using the default value: delegate to the default parser. 
if inputs is None: diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py index 583b2893773c81011b32d9dd6d4dea6b91bf0aef..e74102bbdfdc1c174f633c08eaf6ea23b62e949d 100644 --- a/declearn/model/api/_model.py +++ b/declearn/model/api/_model.py @@ -18,7 +18,7 @@ """Model abstraction API.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar import numpy as np from typing_extensions import Self # future: import from typing (py >=3.11) @@ -33,8 +33,12 @@ __all__ = [ ] +VectorT = TypeVar("VectorT", bound=Vector) +"""Type-annotation for the Vector subclass proper to a given Model.""" + + @create_types_registry -class Model(metaclass=ABCMeta): +class Model(Generic[VectorT], metaclass=ABCMeta): """Abstract class defining an API to manipulate a ML model. A 'Model' is an abstraction that defines a generic interface @@ -59,6 +63,21 @@ class Model(metaclass=ABCMeta): """Instantiate a Model interface wrapping a 'model' object.""" self._model = model + def get_wrapped_model(self) -> Any: + """Getter to access the wrapped framework-specific model object. + + This getter should be used sparingly, so as to avoid undesirable + side effects. In particular, it should not be used in declearn + backend code (but may be in examples or tests), as it is merely + a way for end-users to access the wrapped model after training. + + Returns + ------- + model: + Wrapped model, of (framework/Model-subclass)-specific type. + """ + return self._model + @property @abstractmethod def device_policy( @@ -119,7 +138,7 @@ class Model(metaclass=ABCMeta): def get_weights( self, trainable: bool = False, - ) -> Vector: + ) -> VectorT: """Return the model's weights, optionally excluding frozen ones. Parameters @@ -140,7 +159,7 @@ class Model(metaclass=ABCMeta): @abstractmethod def set_weights( self, - weights: Vector, + weights: VectorT, trainable: bool = False, ) -> None: """Assign values to the model's weights. @@ -176,7 +195,7 @@ class Model(metaclass=ABCMeta): self, batch: Batch, max_norm: Optional[float] = None, - ) -> Vector: + ) -> VectorT: """Compute and return gradients computed over a given data batch. Compute the average gradients of the model's loss with respect @@ -204,7 +223,7 @@ class Model(metaclass=ABCMeta): @abstractmethod def apply_updates( self, - updates: Vector, + updates: VectorT, ) -> None: """Apply updates to the model's weights.""" diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py index 294addd2db35007fb9fec270cf09168a6eb0c65e..f407dc3b3c609509b40c9a4c3ee17c9b3af40ef0 100644 --- a/declearn/model/api/_vector.py +++ b/declearn/model/api/_vector.py @@ -19,7 +19,10 @@ import operator from abc import ABCMeta, abstractmethod -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union +from typing import ( + # fmt: off + Any, Callable, Dict, Generic, Optional, Set, Tuple, Type, TypeVar, Union +) from numpy.typing import ArrayLike from typing_extensions import Self # future: import from typing (Py>=3.11) @@ -37,10 +40,15 @@ __all__ = [ VECTOR_TYPES = {} # type: Dict[Type[Any], Type[Vector]] +"""Private constant holding registered Vector types.""" + + +T = TypeVar("T") +"""Type-annotation for the data structures proper to a given Vector class.""" @create_types_registry -class Vector(metaclass=ABCMeta): +class Vector(Generic[T], metaclass=ABCMeta): """Abstract class defining an API to manipulate (sets of) data arrays. 
A Vector is an abstraction used to wrap a collection of data @@ -62,27 +70,27 @@ class Vector(metaclass=ABCMeta): """ @property - def _op_add(self) -> Callable[[Any, Any], Any]: + def _op_add(self) -> Callable[[Any, Any], T]: """Framework-compatible addition operator.""" return operator.add @property - def _op_sub(self) -> Callable[[Any, Any], Any]: + def _op_sub(self) -> Callable[[Any, Any], T]: """Framework-compatible substraction operator.""" return operator.sub @property - def _op_mul(self) -> Callable[[Any, Any], Any]: + def _op_mul(self) -> Callable[[Any, Any], T]: """Framework-compatible multiplication operator.""" return operator.mul @property - def _op_div(self) -> Callable[[Any, Any], Any]: + def _op_div(self) -> Callable[[Any, Any], T]: """Framework-compatible true division operator.""" return operator.truediv @property - def _op_pow(self) -> Callable[[Any, Any], Any]: + def _op_pow(self) -> Callable[[Any, Any], T]: """Framework-compatible power operator.""" return operator.pow @@ -108,13 +116,13 @@ class Vector(metaclass=ABCMeta): def __init__( self, - coefs: Dict[str, Any], + coefs: Dict[str, T], ) -> None: """Instantiate the Vector to wrap a collection of data arrays. Parameters ---------- - coefs: dict[str, any] + coefs: dict[str, <T>] Dict grouping a named collection of data arrays. The supported types of that dict's values depends on the concrete `Vector` subclass being used. @@ -123,7 +131,7 @@ class Vector(metaclass=ABCMeta): @staticmethod def build( - coefs: Dict[str, Any], + coefs: Dict[str, T], ) -> "Vector": """Instantiate a Vector, inferring its exact subtype from coefs'. @@ -136,7 +144,7 @@ class Vector(metaclass=ABCMeta): Parameters ---------- - coefs: dict[str, any] + coefs: dict[str, <T>] Dict grouping a named collection of data arrays, that all belong to the same framework. @@ -189,7 +197,10 @@ class Vector(metaclass=ABCMeta): indexed by the coefficient's name. """ try: - return {key: coef.shape for key, coef in self.coefs.items()} + return { + key: coef.shape # type: ignore # exception caught + for key, coef in self.coefs.items() + } except AttributeError as exc: raise NotImplementedError( "Wrapped coefficients appear not to implement `.shape`.\n" @@ -210,7 +221,10 @@ class Vector(metaclass=ABCMeta): concrete framework of the Vector. """ try: - return {key: str(coef.dtype) for key, coef in self.coefs.items()} + return { + key: str(coef.dtype) # type: ignore # exception caught + for key, coef in self.coefs.items() + } except AttributeError as exc: raise NotImplementedError( "Wrapped coefficients appear not to implement `.dtype`.\n" @@ -261,7 +275,7 @@ class Vector(metaclass=ABCMeta): def apply_func( self, - func: Callable[..., Any], + func: Callable[..., T], *args: Any, **kwargs: Any, ) -> Self: @@ -290,14 +304,15 @@ class Vector(metaclass=ABCMeta): def _apply_operation( self, other: Any, - func: Callable[[Any, Any], Any], + func: Callable[[Any, Any], T], ) -> Self: """Apply an operation to combine this vector with another. Parameters ---------- - other: Vector - Vector with the same names, shapes and dtypes as this one. + other: + Vector with the same names, shapes and dtypes as this one; + or scalar object on which to operate (e.g. a float value). func: function(<T>, <T>) -> <T> Function to be applied to combine the data arrays stored in this vector and the `other` one. 
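For illustration, here is a minimal sketch of how the `list_aggregators` utility and the generic `Vector[T]` annotations introduced above fit together. This sketch is not part of the patch: the "averaging" identifier (assumed to be `AveragingAggregator`'s registered name), the client names and the coefficient shapes are all illustrative assumptions.

```python
import numpy as np

from declearn.aggregator import list_aggregators
from declearn.model.sklearn import NumpyVector

# List the registered Aggregator classes and pick one by identifier key
# ("averaging" is assumed to be AveragingAggregator's registered name).
aggregator = list_aggregators()["averaging"]()

# Wrap per-client model updates as NumpyVector (i.e. Vector[np.ndarray]).
rng = np.random.default_rng(seed=0)
updates = {
    client: NumpyVector({"kernel": rng.normal(size=(4, 2))})
    for client in ("client_a", "client_b", "client_c")
}
n_steps = {client: 10 for client in updates}

# The Vector[T] generics make the round-trip type explicit: since
# 'aggregate' maps Dict[str, Vector[T]] to Vector[T], type-checkers
# can tell that aggregating NumpyVector inputs yields a NumpyVector.
gradients = aggregator.aggregate(updates, n_steps)
assert isinstance(gradients, NumpyVector)
```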
diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py index 17d19c87af034f41c9064dad7ac9460bf1cafb48..7ed66f802f670dd912102cdc33af24f94a5fbd66 100644 --- a/declearn/model/haiku/_model.py +++ b/declearn/model/haiku/_model.py @@ -22,10 +22,7 @@ import inspect import io import warnings from random import SystemRandom -from typing import ( - # fmt: off - Any, Callable, Dict, List, Optional, Set, Tuple, Union -) +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import haiku as hk import jax @@ -205,7 +202,7 @@ class HaikuModel(Model): params = {k: v for k, v in params.items() if k in self._trainable} return JaxNumpyVector(params) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: JaxNumpyVector, trainable: bool = False, @@ -466,7 +463,7 @@ class HaikuModel(Model): output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)] return output # type: ignore - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: JaxNumpyVector, ) -> None: diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py index 74ac8cafe77c12898d9e56b1da7166da51163a65..e1f106f141848d5686297c23289d610e025ec1a2 100644 --- a/declearn/model/haiku/_vector.py +++ b/declearn/model/haiku/_vector.py @@ -74,23 +74,23 @@ class JaxNumpyVector(Vector): """ @property - def _op_add(self) -> Callable[[Any, Any], Any]: + def _op_add(self) -> Callable[[Any, Any], jax.Array]: return jnp.add @property - def _op_sub(self) -> Callable[[Any, Any], Any]: + def _op_sub(self) -> Callable[[Any, Any], jax.Array]: return jnp.subtract @property - def _op_mul(self) -> Callable[[Any, Any], Any]: + def _op_mul(self) -> Callable[[Any, Any], jax.Array]: return jnp.multiply @property - def _op_div(self) -> Callable[[Any, Any], Any]: + def _op_div(self) -> Callable[[Any, Any], jax.Array]: return jnp.divide @property - def _op_pow(self) -> Callable[[Any, Any], Any]: + def _op_pow(self) -> Callable[[Any, Any], jax.Array]: return jnp.power @property @@ -104,7 +104,7 @@ class JaxNumpyVector(Vector): def _apply_operation( self, other: Any, - func: Callable[[Any, Any], Any], + func: Callable[[jax.Array, Any], jax.Array], ) -> Self: # Ensure 'other' JaxNumpyVector shares this vector's device placement. 
if isinstance(other, JaxNumpyVector): diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 49da62cfb0e0e42ad69976c8a225c990aff805f9..9fba66ebc6b533867e8f86ef465d1f98a8246117 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -58,29 +58,35 @@ class NumpyVector(Vector): """ @property - def _op_add(self) -> Callable[[Any, Any], Any]: + def _op_add(self) -> Callable[[Any, Any], np.ndarray]: return np.add @property - def _op_sub(self) -> Callable[[Any, Any], Any]: + def _op_sub(self) -> Callable[[Any, Any], np.ndarray]: return np.subtract @property - def _op_mul(self) -> Callable[[Any, Any], Any]: + def _op_mul(self) -> Callable[[Any, Any], np.ndarray]: return np.multiply @property - def _op_div(self) -> Callable[[Any, Any], Any]: + def _op_div(self) -> Callable[[Any, Any], np.ndarray]: return np.divide @property - def _op_pow(self) -> Callable[[Any, Any], Any]: + def _op_pow(self) -> Callable[[Any, Any], np.ndarray]: return np.power - def __init__(self, coefs: Dict[str, np.ndarray]) -> None: + def __init__( + self, + coefs: Dict[str, np.ndarray], + ) -> None: super().__init__(coefs) - def __eq__(self, other: Any) -> bool: + def __eq__( + self, + other: Any, + ) -> bool: valid = isinstance(other, NumpyVector) if valid: valid = self.coefs.keys() == other.coefs.keys() diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 18ad7877d2e9f27fae28e0ffaa8807d8cbfcb4e9..6ab4c88087b4cc775f17aa9dc3b6ec9f9d19b78f 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -282,7 +282,7 @@ class SklearnSGDModel(Model): } return NumpyVector(weights) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: NumpyVector, trainable: bool = False, @@ -356,7 +356,7 @@ class SklearnSGDModel(Model): # Compute gradients based on weights' update. return w_srt - w_end - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: NumpyVector, ) -> None: diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 581c285ca2d0ec80524b392821942d36613e5606..59260eaededbf9182178d82ec282c68964867e27 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -183,7 +183,7 @@ class TensorflowModel(Model): ) return TensorflowVector({var.name: var.value() for var in variables}) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: TensorflowVector, trainable: bool = False, @@ -319,7 +319,7 @@ class TensorflowModel(Model): outp.append(tf.reduce_mean(grad * s_wght, axis=0)) return outp - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: TensorflowVector, ) -> None: diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 7491c3f53b444f8216143a39ce0d15249c810ebc..f8a5dd6f1396d6d6909edbf070ee616e6cc3a6ec 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -168,7 +168,7 @@ class TorchModel(Model): # Note: calling `tensor.clone()` to return a copy rather than a view. 
return TorchVector({k: t.detach().clone() for k, t in weights.items()}) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: TorchVector, trainable: bool = False, @@ -378,7 +378,7 @@ class TorchModel(Model): return grads_fn return functorch.compile.aot_function(grads_fn, functorch.compile.nop) - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: TorchVector, ) -> None: diff --git a/declearn/optimizer/__init__.py b/declearn/optimizer/__init__.py index 520eeafa235b9b7ebce84c7327c96455deb09de0..44cc4396bfb6bfec5e264a18ce15b371948550d4 100644 --- a/declearn/optimizer/__init__.py +++ b/declearn/optimizer/__init__.py @@ -32,8 +32,16 @@ Submodules providing with plug-in algorithms: Gradients-alteration algorithms, implemented as plug-in modules. * [regularizers][declearn.optimizer.regularizers]: Loss-regularization algorithms, implemented as plug-in modules. + +Utils to list available plug-ins: + +* [list_optim_modules][declearn.optimizer.list_optim_modules]: + Return a mapping of registered OptiModule subclasses. +* [list_optim_regularizers][declearn.optimizer.list_optim_regularizers]: + Return a mapping of registered Regularizer subclasses. """ from . import modules, regularizers from ._base import Optimizer +from ._utils import list_optim_modules, list_optim_regularizers diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py index 5e14779cfa5ce52a16fb4143b9a61064218b5aa2..01214cd26bb1db1cc3fb496a5805f4c148f59068 100644 --- a/declearn/optimizer/_base.py +++ b/declearn/optimizer/_base.py @@ -17,7 +17,10 @@ """Base class to define gradient-descent-based optimizers.""" -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + # fmt: off + Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union +) from typing_extensions import Self # future: import from typing (py >=3.11) @@ -31,6 +34,9 @@ __all__ = [ ] +T = TypeVar("T") + + class Optimizer: """Base class to define gradient-descent-based optimizers. @@ -103,6 +109,13 @@ class Optimizer: [1] Loshchilov & Hutter, 2019. Decoupled Weight Decay Regularization. https://arxiv.org/abs/1711.05101 + + See also + -------- + - [declearn.optimizer.list_optim_modules][]: + Return a mapping of registered OptiModule subclasses. + - [declearn.optimizer.list_optim_regularizers][]: + Return a mapping of registered Regularizer subclasses. """ def __init__( @@ -255,9 +268,9 @@ class Optimizer: def compute_updates_from_gradients( self, - model: Model, - gradients: Vector, - ) -> Vector: + model: Model[Vector[T]], + gradients: Vector[T], + ) -> Vector[T]: """Compute and return model updates based on pre-computed gradients. Parameters @@ -393,8 +406,8 @@ class Optimizer: def apply_gradients( self, - model: Model, - gradients: Vector, + model: Model[Vector[T]], + gradients: Vector[T], ) -> None: """Compute and apply model updates based on pre-computed gradients. diff --git a/declearn/optimizer/_utils.py b/declearn/optimizer/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf56ace356d8917843db4a4a5daa18ee28bb127 --- /dev/null +++ b/declearn/optimizer/_utils.py @@ -0,0 +1,86 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to list available optimizer plug-ins (OptiModule and Regularizer).""" + +from typing import Dict, Type + +from declearn.optimizer.modules import OptiModule +from declearn.optimizer.regularizers import Regularizer +from declearn.utils import access_types_mapping + + +__all__ = [ + "list_optim_modules", + "list_optim_regularizers", +] + + +def list_optim_modules() -> Dict[str, Type[OptiModule]]: + """Return a mapping of registered OptiModule subclasses. + + This function aims at making it easy for end-users to list and access + all available OptiModule optimizer plug-ins at any given time. The + returned dict uses unique identifier keys, which may be used to add + the associated plug-in to a [declearn.optimizer.Optimizer][] without + going through the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.optimizer.modules.OptiModule][]: + API-defining abstract base class for the OptiModule plug-ins. + * [declearn.optimizer.list_optim_regularizers][]: + Counterpart function for Regularizer plug-ins. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to OptiModule + class constructors. + """ + return access_types_mapping("OptiModule") + + +def list_optim_regularizers() -> Dict[str, Type[Regularizer]]: + """Return a mapping of registered Regularizer subclasses. + + This function aims at making it easy for end-users to list and access + all available Regularizer optimizer plug-ins at any given time. The + returned dict uses unique identifier keys, which may be used to add + the associated plug-in to a [declearn.optimizer.Optimizer][] without + going through the fuss of importing and instantiating it manually. + + Note that the mapping will include all declearn-provided plug-ins, + but also registered plug-ins provided by user or third-party code. + + See also + -------- + * [declearn.optimizer.regularizers.Regularizer][]: + API-defining abstract base class for the Regularizer plug-ins. + * [declearn.optimizer.list_optim_modules][]: + Counterpart function for OptiModule plug-ins. + + Returns + ------- + mapping: + Dictionary mapping unique str identifiers to Regularizer + class constructors. + """ + return access_types_mapping("Regularizer") diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 197346e0bc425f94e053cf336d3eb05bfd5e5468..55d3f37cfa6c049f2628d8581e2803590a4d6b2b 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -18,7 +18,7 @@ """Base API for plug-in gradients-alteration algorithms.""" from abc import ABCMeta, abstractmethod -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) @@ -34,6 +34,9 @@ __all__ = [ ] +T = TypeVar("T") + + @create_types_registry class OptiModule(metaclass=ABCMeta): """Abstract class defining an API to implement gradients adaptation tools. 
@@ -117,8 +120,8 @@ class OptiModule(metaclass=ABCMeta):
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+    ) -> Vector[T]:
         """Apply the module's algorithm to input gradients.
 
         Please refer to the module's main docstring for details
diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py
index f593a6bce116a0ef60bb51fc57c86bfb5d21660c..2409e76e819c31d2d3d57d433de6d15e0c5f724e 100644
--- a/declearn/optimizer/modules/_noise.py
+++ b/declearn/optimizer/modules/_noise.py
@@ -81,7 +81,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
         gradients: Vector,
     ) -> Vector:
         if not NumpyVector in gradients.compatible_vector_types:
-            raise TypeError(
+            raise TypeError(  # pragma: no cover
                 f"{self.__class__.__name__} requires input gradients to "
                 "be compatible with NumpyVector, which is not the case "
                 f"of {type(gradients).__name__}."
@@ -95,7 +95,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
             for key in gradients.coefs
         }
         # Add the sampled noise to the gradients and return them.
-        # Silence warnings about sparse gradients getting sparsified.
+        # Silence warnings about sparse gradients getting densified.
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", ".*densifying.*", RuntimeWarning)
             return gradients + NumpyVector(noise)
@@ -166,4 +166,6 @@ class GaussianNoiseModule(NoiseModule):
             # false-positive; pylint: disable=no-member
             return self._rng.normal(scale=self.std, size=shape).astype(dtype)
         # Theoretically-unreachable case.
-        raise RuntimeError("Unexpected `GaussianeNoiseModule._rng` type.")
+        raise RuntimeError(  # pragma: no cover
+            "Unexpected `GaussianNoiseModule._rng` type."
+        )
diff --git a/declearn/optimizer/regularizers/_api.py b/declearn/optimizer/regularizers/_api.py
index 1a9dfa5bc2151971a101c48e11ce7c480d4ec7ca..187f5eb5ef2875b11f17da7dfda697fe4c50c38c 100644
--- a/declearn/optimizer/regularizers/_api.py
+++ b/declearn/optimizer/regularizers/_api.py
@@ -18,7 +18,7 @@
 """Base API for loss regularization optimizer plug-ins."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict
+from typing import Any, ClassVar, Dict, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
@@ -34,6 +34,9 @@ __all__ = [
 ]
 
 
+T = TypeVar("T")
+
+
 @create_types_registry
 class Regularizer(metaclass=ABCMeta):
     """Abstract class defining an API to implement loss-regularizers.
@@ -115,9 +118,9 @@ class Regularizer(metaclass=ABCMeta):
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-        weights: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+        weights: Vector[T],
+    ) -> Vector[T]:
         """Compute and add the regularization term's derivative to gradients.
 
         Parameters
diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py
index 293cf5297c9b9c8f3236c5a1bef9d50f8c4a1089..c33ae875dd1af0f5df6ca02e610877bcee62818a 100644
--- a/declearn/quickrun/_run.py
+++ b/declearn/quickrun/_run.py
@@ -175,7 +175,7 @@ def server_to_client_network(
     "Convert server network config to client network config."
    return NetworkClientConfig.from_params(
         protocol=network_cfg.protocol,
-        server_uri=f"ws://localhost:{network_cfg.port}",
+        server_uri=network_cfg.build_server().uri,
         name="replaceme",
     )
diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py
index a39ac540d7d1001e466d3560ce3026edee38fca9..8e4b7ca3cbac3585ba444ee6f542b3c6c291a774 100644
--- a/declearn/utils/_multiprocess.py
+++ b/declearn/utils/_multiprocess.py
@@ -22,7 +22,7 @@ import multiprocessing as mp
 import sys
 import traceback
 from queue import Queue
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 __all__ = [
     "run_as_processes",
@@ -105,12 +105,9 @@ def run_as_processes(
 def add_exception_catching(
     func: Callable[..., Any],
     queue: Queue,
-    name: Optional[str] = None,
+    name: str,
 ) -> Callable[..., Any]:
     """Wrap a function to catch exceptions and put them in a Queue."""
-    if not name:
-        name = func.__name__
-
     return functools.partial(wrapped, func=func, queue=queue, name=name)
diff --git a/declearn/utils/_toml_config.py b/declearn/utils/_toml_config.py
index 012a2255d09a48b58314e5dc2ddc789255b4754b..adc3c2b87cd6fbf0825b37ff35e4905f032e8b98 100644
--- a/declearn/utils/_toml_config.py
+++ b/declearn/utils/_toml_config.py
@@ -18,6 +18,7 @@
 """Base class to define TOML-parsable configuration containers."""
 
 import dataclasses
+import os
 import typing
 import warnings
@@ -99,7 +100,7 @@ def _isinstance_generic(inputs: Any, typevar: Type) -> bool:
         and all(_isinstance_generic(e, t) for e, t in zip(inputs, args))
     )
     # Unsupported cases.
-    raise TypeError(
+    raise TypeError(  # pragma: no cover
         "Unsupported subscripted generic for instance check: "
         f"'{typevar}' with origin '{origin}'."
     )
@@ -112,26 +113,38 @@ def _parse_float(src: str) -> Optional[float]:
 
 def _instantiate_field(
     field: dataclasses.Field,  # future: dataclasses.Field[T] (Py >=3.9)
-    *args: Any,
     **kwargs: Any,
 ) -> Any:  # future: T
     """Instantiate a dataclass field from input args and kwargs.
 
-    This functions is meant to enable automatically building dataclass
-    fields that are annotated to be a union of types, notably optional
-    fields (i.e. Union[T, None]).
+    This function is meant to enable building dataclass object fields
+    that require instantiating from a received value or dict of kwargs.
+
+    It supports doing so even for fields that are annotated to be a
+    union of types (notably optional ones: Union[T, None]), and/or are a
+    `TomlConfig` subclass that should be instantiated via `from_params`
+    rather than `cls(*args, **kwargs)`.
 
     It will raise a TypeError if instantiation fails or if `field.type`
     has and unsupported typing origin. It may also raise any exception
     coming from the target type's `__init__` method.
     """
+
+    def _instantiate(cls: Type[Any]) -> Any:
+        """Try instantiating a given class from the kwargs."""
+        if issubclass(cls, TomlConfig):
+            return cls.from_params(**kwargs)
+        return cls(**kwargs)
+
     origin = typing.get_origin(field.type)
-    if origin is None:  # raw type
-        return field.type(*args, **kwargs)
-    if origin is Union:  # union of types, including optional
+    # Case of a raw type.
+    if origin is None:
+        return _instantiate(field.type)
+    # Case of a union of types (including optional).
+    if origin is Union:
         for cls in typing.get_args(field.type):
             try:
-                return cls(*args, **kwargs)
+                return _instantiate(cls)
             except TypeError:
                 pass
         raise TypeError(
@@ -219,12 +232,14 @@ class TomlConfig:
         for key in kwargs:
             warnings.warn(
                 f"Unsupported keyword argument in {cls.__name__}.from_params: "
-                f"'{key}'. This argument was ignored."
+                f"'{key}'. This argument was ignored.",
+                category=RuntimeWarning,
             )
         return cls(**fields)
 
-    @staticmethod
+    @classmethod
     def default_parser(
+        cls,
         field: dataclasses.Field,  # future: dataclasses.Field[T] (Py >=3.9)
         inputs: Union[str, Dict[str, Any], T, None],
     ) -> Any:
@@ -272,20 +287,19 @@ class TomlConfig:
             raise TypeError(
                 f"Field '{field.name}' does not provide a default value."
             )
-        # Case of str inputs: treat as the path to a TOML file to parse.
-        if isinstance(inputs, str):
-            # If the field implements TOML parsing, call it.
-            if issubclass(field.type, TomlConfig):
-                return field.type.from_toml(inputs)
-            # Otherwise, conduct minimal parsing.
-            with open(inputs, "rb") as file:
-                config = tomllib.load(file, parse_float=_parse_float)
-            section = config.get(field.name, config)  # subsection or full file
-            return (
-                _instantiate_field(field, **section)
-                if isinstance(section, dict)
-                else _instantiate_field(field, section)
-            )
+        # Case of str inputs pointing to a file: try parsing it.
+        if isinstance(inputs, str) and os.path.isfile(inputs):
+            try:
+                with open(inputs, "rb") as file:
+                    config = tomllib.load(file, parse_float=_parse_float)
+            except tomllib.TOMLDecodeError as exc:
+                raise TypeError(
+                    f"Field {field.name}: could not parse secondary TOML file."
+                ) from exc
+            # Look for a subsection, otherwise keep the entire file.
+            section = config.get(field.name, config)
+            # Recursively call this parser.
+            return cls.default_parser(field, section)
         # Case of dict inputs: try instantiating the target type.
         if isinstance(inputs, dict):
             return _instantiate_field(field, **inputs)
@@ -377,7 +391,8 @@ class TomlConfig:
         for name in config:
             warnings.warn(
                 f"Unsupported section encountered in {path} TOML file: "
-                f"'{name}'. This section will be ignored."
+                f"'{name}'. This section will be ignored.",
+                category=RuntimeWarning,
             )
         # Finally, instantiate the FLConfig container.
         return cls.from_params(**params)
diff --git a/test/aggregator/test_aggregator.py b/test/aggregator/test_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b02def27aee7900813a73bef35d4dfaef8d24b
--- /dev/null
+++ b/test/aggregator/test_aggregator.py
@@ -0,0 +1,93 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the 'Aggregator' subclasses."""
+
+import typing
+from typing import Dict, Type
+
+import pytest
+
+from declearn.aggregator import Aggregator, list_aggregators
+from declearn.model.api import Vector
+from declearn.test_utils import (
+    FrameworkType,
+    GradientsTestCase,
+    assert_dict_equal,
+    assert_json_serializable_dict,
+)
+
+
+AGGREGATOR_CLASSES = list_aggregators()
+VECTOR_FRAMEWORKS = typing.get_args(FrameworkType)
+
+
+@pytest.fixture(name="updates")
+def updates_fixture(
+    framework: FrameworkType,
+    n_clients: int = 3,
+) -> Dict[str, Vector]:
+    """Fixture providing deterministic sets of updates Vectors."""
+    return {
+        str(idx): GradientsTestCase(framework, seed=idx).mock_gradient
+        for idx in range(n_clients)
+    }
+
+
+@pytest.mark.parametrize(
+    "agg_cls", AGGREGATOR_CLASSES.values(), ids=AGGREGATOR_CLASSES.keys()
+)
+class TestAggregator:
+    """Shared unit test suite for 'Aggregator' subclasses."""
+
+    @pytest.mark.parametrize("framework", VECTOR_FRAMEWORKS)
+    def test_aggregate(
+        self,
+        agg_cls: Type[Aggregator],
+        updates: Dict[str, Vector],
+    ) -> None:
+        """Test that the 'aggregate' method works properly."""
+        agg = agg_cls()
+        n_steps = {key: 10 for key in updates}
+        outputs = agg.aggregate(updates, n_steps)
+        ref_vec = list(updates.values())[0]
+        assert isinstance(outputs, type(ref_vec))
+        assert outputs.shapes() == ref_vec.shapes()
+        assert outputs.dtypes() == ref_vec.dtypes()
+
+    def test_aggregate_empty(
+        self,
+        agg_cls: Type[Aggregator],
+    ) -> None:
+        """Test that 'aggregate' raises the expected error on empty inputs."""
+        agg = agg_cls()
+        with pytest.raises(TypeError):
+            agg.aggregate(updates={}, n_steps={})
+
+    def test_get_config(self, agg_cls: Type[Aggregator]) -> None:
+        """Test that the 'get_config' method works properly."""
+        agg = agg_cls()
+        cfg = agg.get_config()
+        assert_json_serializable_dict(cfg)
+
+    def test_from_config(self, agg_cls: Type[Aggregator]) -> None:
+        """Test that the 'from_config' method works properly."""
+        agg = agg_cls()
+        cfg = agg.get_config()
+        bis = agg_cls.from_config(cfg)
+        assert isinstance(bis, agg_cls)
+        assert_dict_equal(cfg, bis.get_config())
diff --git a/test/functional/test_quickrun.py b/test/functional/test_quickrun.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f05d97f0f4952c3721c9c0ff3e65b41f14b09ec
--- /dev/null
+++ b/test/functional/test_quickrun.py
@@ -0,0 +1,82 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functional test of the declearn quickrun example."""
+
+import os
+import pathlib
+
+import numpy as np
+
+from declearn.dataset import split_data
+from declearn.quickrun import quickrun
+
+
+MODEL_CODE = """
+from declearn.model.sklearn import SklearnSGDModel
+
+model = SklearnSGDModel.from_parameters(kind="classifier", penalty="none")
+"""
+
+CONFIG_TOML = """
+[network]
+protocol = "websockets"
+host = "127.0.0.1"
+port = 8080
+
+[data]
+
+[optim]
+[optim.client_opt]
+lrate = 0.01
+modules = ["adam"]
+regularizers = ["lasso"]
+
+[run]
+rounds = 2
+[run.register]
+min_clients = 2
+[run.training]
+batch_size = 48
+n_steps = 100
+[run.evaluate]
+batch_size = 128
+
+[experiment]
+metrics = [
+    ["multi-classif", {labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
+]
+"""
+
+
+def test_quickrun_mnist(tmp_path: str) -> None:
+    """Run a very basic MNIST example using 'declearn-quickrun'."""
+    # Download, prepare and split the MNIST dataset into iid shards.
+    split_data(tmp_path, n_shards=2, seed=0)
+    # Flatten out the input images to enable their processing with sklearn.
+    for path in pathlib.Path(tmp_path).glob("data_iid/client_*/*_data.npy"):
+        images = np.load(path)
+        np.save(path, images.reshape((-1, 28 * 28)))
+    # Write down a very basic TOML config and Python model files.
+    model = os.path.join(tmp_path, "model.py")
+    with open(model, "w", encoding="utf-8") as file:
+        file.write(MODEL_CODE)
+    config = os.path.join(tmp_path, "config.toml")
+    with open(config, "w", encoding="utf-8") as file:
+        file.write(CONFIG_TOML)
+    # Run the quickrun experiment.
+    quickrun(config)
diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py
index a6da92e767b7db9760357a00777c46116d13e577..0387075f66dab9bad3635d033e76775c56666624 100644
--- a/test/model/test_tflow.py
+++ b/test/model/test_tflow.py
@@ -183,7 +183,7 @@ class TestTensorflowModel(ModelTestSuite):
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -198,7 +198,7 @@
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        tfmod = getattr(model, "_model")  # type: tf.keras.Sequential
+        tfmod = model.get_wrapped_model()
         tfmod.layers[0].trainable = False  # freeze the first layer's weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -217,7 +217,7 @@
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert policy.idx == 0
-        tfmod = getattr(model, "_model")
+        tfmod = model.get_wrapped_model()
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
diff --git a/test/model/test_torch.py b/test/model/test_torch.py
index 67bdc5cce547fcd98cf7b64a18fed9d81ddfb8d5..dafe8c8cb66b1fbfe12a7201d4bb35a0c4eb23b5 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -236,8 +236,8 @@ class TestTorchModel(ModelTestSuite):
         # Verify that both models have the same device policy.
         assert model.device_policy == other.device_policy
         # Verify that both models have a similar structure of modules.
-        mod_a = list(getattr(model, "_model").modules())
-        mod_b = list(getattr(other, "_model").modules())
+        mod_a = list(model.get_wrapped_model().modules())
+        mod_b = list(other.get_wrapped_model().modules())
         assert len(mod_a) == len(mod_b)
         assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b))
         assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b))
@@ -262,7 +262,7 @@
     ) -> None:
         """Check that `get_weights` behaves properly with frozen weights."""
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
@@ -280,7 +280,7 @@
         """Check that `set_weights` behaves properly with frozen weights."""
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
-        ptmod = getattr(model, "_model")  # type: torch.nn.Module
+        ptmod = model.get_wrapped_model()
         next(ptmod.parameters()).requires_grad = False  # freeze some weights
         w_trn = model.get_weights(trainable=True)
         # Test that `set_weights` works if and only if properly parametrized.
@@ -299,7 +299,7 @@
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "GPU")
         assert (policy.idx == 0) if policy.gpu else (policy.idx is None)
-        ptmod = getattr(model, "_model").module
+        ptmod = model.get_wrapped_model().module
         device_type = "cpu" if test_case.device == "CPU" else "cuda"
         for param in ptmod.parameters():
             assert param.device.type == device_type
diff --git a/test/utils/test_register.py b/test/utils/test_register.py
index 2cf5085c0062835c0d6a8b6df169611275fcbffe..5d47b4f4ab6a1c6e1fa09592b59a8c9d7e30fe32 100644
--- a/test/utils/test_register.py
+++ b/test/utils/test_register.py
@@ -33,9 +33,13 @@ from declearn.utils import (
 def test_create_types_registry() -> None:
     """Unit tests for 'create_types_registry'."""
     group = f"test_{time.time_ns()}"
-    assert create_types_registry(object, group) is object
+
+    class AnyClass:  # pylint: disable=all
+        pass
+
+    assert create_types_registry(AnyClass, group) is AnyClass
     with pytest.raises(KeyError):
-        create_types_registry(object, group)
+        create_types_registry(AnyClass, group)
 
 
 def test_register_type() -> None:
@@ -75,6 +79,9 @@ def test_register_type_fails() -> None:
     group = f"test_{time.time_ns()}"
     with pytest.raises(KeyError):
         register_type(BaseClass, name="base", group=group)
+    # Try registering in any group, with no valid parent group.
+    with pytest.raises(TypeError):
+        register_type(BaseClass, name="base", group=None)
     # Try registering in a group with wrong class constraints.
     create_types_registry(BaseClass, group)
     with pytest.raises(TypeError):
@@ -103,10 +110,31 @@ def test_access_registered() -> None:
     name_2 = f"test_{time.time_ns()}"
     with pytest.raises(KeyError):
         access_registered(name_2, group=name)  # invalid name under group
+    with pytest.raises(KeyError):
+        access_registered(name_2, group=None)  # invalid name under any group
     with pytest.raises(KeyError):
         access_registered(name, group=name_2)  # non-existing group
 
 
+def test_register_unspecified_group() -> None:
+    """Unit tests for type-registration with implicit group membership."""
+    group = f"test_{time.time_ns()}"
+
+    # Define a parent class and an associated type registry.
+ @create_types_registry(name=group) + class Parent: # pylint: disable=all + pass + + # Define a child class, and register it without specifying the group. + @register_type(name="new-child") + class Child(Parent): # pylint: disable=all + pass + + # Verify that the class was put into the proper group. + assert access_registered("new-child") is Child + assert access_registration_info(Child) == ("new-child", group) + + def test_access_registeration_info() -> None: """Unit tests for 'access_registration_info'.""" @@ -157,3 +185,7 @@ def test_access_types_mapping() -> None: assert mapping != access_types_mapping(group=group) with pytest.raises(KeyError): access_registered("renamed", group=group) + + # Test that the expected exception is raised for non-existing groups. + with pytest.raises(KeyError): + access_types_mapping(group=f"test_{time.time_ns()}") diff --git a/test/utils/test_toml.py b/test/utils/test_toml.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8a04d981648591e1c9a30536daae04cd2af2c5 --- /dev/null +++ b/test/utils/test_toml.py @@ -0,0 +1,383 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the `TomlConfig` util.""" + +import dataclasses +import os +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytest +from typing_extensions import Self + +from declearn.utils import TomlConfig + + +class Custom: + """Custom class that requires specific TOML parsing.""" + + def __init__( + self, + value: Any = "default", + _from_config: bool = False, + ) -> None: + """Default builder for Custom instances.""" + self.value = value + self.f_cfg = _from_config + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Custom): + return self.value == other.value + return NotImplemented + + @classmethod + def from_config(cls, **kwargs: Any) -> Self: + """Alternative builder for Custom instances.""" + return cls(_from_config=True, **kwargs) + + +@dataclasses.dataclass +class DemoTomlConfig(TomlConfig): + """Demonstration TomlConfig subclass.""" + + req_int: int + req_lst: List[str] + opt_str: str = "default" + opt_tup: Optional[Tuple[int, int]] = None + opt_dct: Dict[str, float] = dataclasses.field(default_factory=dict) + opt_obj: Custom = dataclasses.field(default_factory=Custom) + opt_unc: Union[str, Custom, None] = None + + @classmethod + def parse_opt_tup( + cls, + field: dataclasses.Field, + inputs: Any, + ) -> Optional[Tuple[int, int]]: + """Custom parser for `opt_tup`, adding list-to-tuple conversion.""" + if isinstance(inputs, list): + inputs = tuple(inputs) + return cls.default_parser(field, inputs) + + @classmethod + def get_field(cls, name: str) -> dataclasses.Field: + """Access the definition of a given dataclass field.""" + return {field.name: field for field in dataclasses.fields(cls)}[name] + + +class TestTomlConfigDefaultParser: + """Unit tests for `TomlConfig.default_parser`, using a demo 
subclass."""
+
+    def test_int(self) -> None:
+        """Test that the parser works for an int field."""
+        field = DemoTomlConfig.get_field("req_int")
+        assert TomlConfig.default_parser(field, 42) == 42
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, 42.0)
+
+    def test_lst(self) -> None:
+        """Test that the parser works for a list of str field."""
+        field = DemoTomlConfig.get_field("req_lst")
+        # Test with a list of str.
+        value = ["this", "is", "a", "test"]
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with an empty list.
+        value = []
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with a None value (the required field has no default).
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, None)
+        # Test with a mixed-type list.
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, ["this", "fails", 0])
+
+    def test_opt_str(self) -> None:
+        """Test that the parser works for an optional str field."""
+        field = DemoTomlConfig.get_field("opt_str")
+        # Test without a value.
+        assert TomlConfig.default_parser(field, None) is field.default
+        # Test with a valid value.
+        assert TomlConfig.default_parser(field, "test") == "test"
+        # Test with an invalid value.
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, 0)
+
+    def test_opt_tup(self) -> None:
+        """Test that the parser works for an optional tuple of int field."""
+        field = DemoTomlConfig.get_field("opt_tup")
+        # Test without a value.
+        assert TomlConfig.default_parser(field, None) is field.default
+        # Test with a valid value.
+        value = (12, 15)
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with invalid values (wrong type, length or internal types).
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, [12, 15])
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, (12, 15, 18))
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, (12, "15"))
+
+    def test_opt_dct(self) -> None:
+        """Test that the parser works for an optional dict field."""
+        field = DemoTomlConfig.get_field("opt_dct")
+        # Test without a value.
+        assert TomlConfig.default_parser(field, None) == {}
+        # Test with a valid value.
+        value = {"a": 0.0, "b": 1.0}
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with invalid values (wrong type, key types or value types).
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, "test")
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, {0: 0.0})  # type: ignore
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, {"a": "val"})
+
+    def test_opt_obj(self) -> None:
+        """Test that the parser works for an optional custom object field."""
+        field = DemoTomlConfig.get_field("opt_obj")
+        # Test without a value.
+        assert TomlConfig.default_parser(field, None) == Custom()
+        # Test with a valid value.
+        value = Custom(value=42.0)
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with kwargs for the object itself.
+        built = TomlConfig.default_parser(field, {"value": 0.0})
+        assert isinstance(built, Custom)
+        assert built.value == 0.0
+        # Test with an invalid value.
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, "invalid")
+
+    def test_opt_unc(self) -> None:
+        """Test that the parser works for a union with custom types."""
+        field = DemoTomlConfig.get_field("opt_unc")
+        # Test without a value.
+        assert TomlConfig.default_parser(field, None) is None
+        # Test with a valid str value.
+        assert TomlConfig.default_parser(field, "unc") == "unc"
+        # Test with a valid Custom object.
+        value = Custom(value="test")
+        assert TomlConfig.default_parser(field, value) is value
+        # Test with kwargs for a Custom object.
+        built = TomlConfig.default_parser(field, {"value": "test"})
+        assert isinstance(built, Custom)
+        assert built.value == "test"
+        # Test with an invalid value.
+        with pytest.raises(TypeError):
+            TomlConfig.default_parser(field, {"invalid": "kwarg"})
+
+
+class TestTomlConfigFromParams:
+    """Unit tests for `TomlConfig.from_params`, using a demo subclass."""
+
+    exhaustive_params = {
+        "req_int": 0,
+        "req_lst": ["test"],
+        "opt_str": "test",
+        "opt_tup": (0, 1),
+        "opt_dct": {"key": 0.0},
+        "opt_obj": Custom(value=0),
+        "opt_unc": Custom(value=1),
+    }
+
+    def test_all_params(self) -> None:
+        """Test that parsing from an exhaustive dict of valid params works."""
+        parsed = DemoTomlConfig.from_params(**self.exhaustive_params)
+        assert isinstance(parsed, DemoTomlConfig)
+        for key, val in self.exhaustive_params.items():
+            assert getattr(parsed, key) == val
+
+    def test_partial_params(self) -> None:
+        """Test that parsing without some optional params works."""
+        params = {"req_int": 0, "req_lst": ["test"]}
+        parsed = DemoTomlConfig.from_params(**params)
+        assert isinstance(parsed, DemoTomlConfig)
+        assert all(getattr(parsed, key) == val for key, val in params.items())
+
+    def test_bad_params(self) -> None:
+        """Test that parsing with some bad params fails."""
+        # Missing required keys.
+        with pytest.raises(RuntimeError):
+            DemoTomlConfig.from_params(req_int=0, opt_str="test")
+        # Invalid value types.
+        with pytest.raises(RuntimeError):
+            DemoTomlConfig.from_params(req_int=0, req_lst=1)
+
+    def test_extra_params(self) -> None:
+        """Test that providing extra parameters raises a warning."""
+        with pytest.warns(RuntimeWarning):
+            DemoTomlConfig.from_params(req_int=0, req_lst=["1"], extra=2)
+
+
+class TestTomlConfigFromToml:
+    """Unit tests for `TomlConfig.from_toml`, using a demo subclass."""
+
+    exhaustive_toml = """
+    req_int = 0
+    req_lst = ["test"]
+    opt_str = "test"
+    opt_tup = [0, 1]
+    opt_dct = {key = 0.0}
+    opt_obj = {value = 0}
+    opt_unc = {value = 1}
+    """
+
+    def test_from_exhaustive_toml(self, tmp_path: str) -> None:
+        """Test that parsing from an exhaustive and valid TOML file works."""
+        # Export the TOML config file.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(self.exhaustive_toml)
+        # Parse it and verify the outputs.
+        parsed = DemoTomlConfig.from_toml(path)
+        assert isinstance(parsed, DemoTomlConfig)
+        for key, val in TestTomlConfigFromParams.exhaustive_params.items():
+            assert getattr(parsed, key) == val
+
+    def test_wrong_file_fails(self, tmp_path: str) -> None:
+        """Test that a proper error is raised when parsing an invalid file."""
+        # Export the non-TOML file.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("this is not a TOML file")
+        # Verify that the parsing error is properly caught and wrapped.
+        with pytest.raises(RuntimeError):
+            DemoTomlConfig.from_toml(path)
+
+    def test_warn_user(self, tmp_path: str) -> None:
+        """Test that the 'warn_user' boolean flag works properly."""
+        # Export the TOML config file, with an extra field.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(self.exhaustive_toml + "\nunused = nan")
+        # Parse it, expecting a warning.
+        with pytest.warns(RuntimeWarning):
+            DemoTomlConfig.from_toml(path)
+        # Parse it again, expecting no warning.
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            DemoTomlConfig.from_toml(path, warn_user=False)
+
+    def test_use_section(self, tmp_path: str) -> None:
+        """Test that the 'use_section' option works properly."""
+        # Export the TOML config file, within a wrapping section.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write("[mysection]\n" + self.exhaustive_toml)
+        # Verify that a basic attempt at parsing fails.
+        with pytest.raises(RuntimeError):
+            DemoTomlConfig.from_toml(path)
+        # Parse it and verify the outputs.
+        parsed = DemoTomlConfig.from_toml(path, use_section="mysection")
+        assert isinstance(parsed, DemoTomlConfig)
+        for key, val in TestTomlConfigFromParams.exhaustive_params.items():
+            assert getattr(parsed, key) == val
+
+    def test_use_section_fails(self, tmp_path: str) -> None:
+        """Test that the 'use_section' option fails properly."""
+        # Export the TOML config file.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(self.exhaustive_toml)
+        # Verify that targeting a non-existing section fails.
+        with pytest.raises(KeyError):
+            DemoTomlConfig.from_toml(path, use_section="mysection")
+        # Verify that the error can be skipped using `section_fail_ok`.
+        parsed = DemoTomlConfig.from_toml(
+            path, use_section="mysection", section_fail_ok=True
+        )
+        assert isinstance(parsed, DemoTomlConfig)
+        for key, val in TestTomlConfigFromParams.exhaustive_params.items():
+            assert getattr(parsed, key) == val
+
+
+@dataclasses.dataclass
+class ComplexTomlConfig(TomlConfig):
+    """Demonstration TomlConfig subclass with some nestedness."""
+
+    demo_a: DemoTomlConfig
+    demo_b: DemoTomlConfig
+
+
+class TestTomlConfigNested:
+    """Unit test for a complex, nested 'TomlConfig' subclass."""
+
+    def test_multisection_config(self, tmp_path: str) -> None:
+        """Test parsing a multi-section TOML file using nested parsers."""
+        # Export the multi-section TOML config file.
+        path = os.path.join(tmp_path, "config.toml")
+        toml = (
+            f"[demo_a]\n{TestTomlConfigFromToml.exhaustive_toml}\n"
+            f"[demo_b]\n{TestTomlConfigFromToml.exhaustive_toml}\n"
+        )
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(toml)
+        # Verify that it can be properly parsed.
+        parsed = ComplexTomlConfig.from_toml(path)
+        assert isinstance(parsed, ComplexTomlConfig)
+        for fname in ("demo_a", "demo_b"):
+            field = getattr(parsed, fname)
+            assert isinstance(field, DemoTomlConfig)
+            for key, val in TestTomlConfigFromParams.exhaustive_params.items():
+                assert getattr(field, key) == val
+
+    def test_multifiles_config(self, tmp_path: str) -> None:
+        """Test parsing a multi-files TOML config using nested parsers."""
+        # Export the multi-files TOML config.
+        # demo_a: section file
+        path_a = os.path.join(tmp_path, "demo_a.toml")
+        with open(path_a, "w", encoding="utf-8") as file:
+            file.write(f"[demo_a]\n{TestTomlConfigFromToml.exhaustive_toml}")
+        # demo_b: full-file
+        path_b = os.path.join(tmp_path, "demo_b.toml")
+        with open(path_b, "w", encoding="utf-8") as file:
+            file.write(TestTomlConfigFromToml.exhaustive_toml)
+        # main config file
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(f"demo_a = '{path_a}'\ndemo_b = '{path_b}'")
+        # Verify that it can be properly parsed.
+        parsed = ComplexTomlConfig.from_toml(path)
+        assert isinstance(parsed, ComplexTomlConfig)
+        for fname in ("demo_a", "demo_b"):
+            field = getattr(parsed, fname)
+            assert isinstance(field, DemoTomlConfig)
+            for key, val in TestTomlConfigFromParams.exhaustive_params.items():
+                assert getattr(field, key) == val
+
+    def test_multifiles_config_fails(self, tmp_path: str) -> None:
+        """Test parsing a multi-files TOML config with invalid contents."""
+        # Export a badly-formatted secondary file.
+        path_bad = os.path.join(tmp_path, "bad.toml")
+        with open(path_bad, "w", encoding="utf-8") as file:
+            file.write("this file is not a TOML one")
+        # Export a main config file that points to the invalid one.
+        path = os.path.join(tmp_path, "config.toml")
+        with open(path, "w", encoding="utf-8") as file:
+            file.write(f"demo_a = '{path_bad}'\ndemo_b = '{path_bad}'")
+        # Verify that the parsing error is properly caught and wrapped.
+        with pytest.raises(RuntimeError):
+            ComplexTomlConfig.from_toml(path)
+        field = {
+            field.name: field
+            for field in dataclasses.fields(ComplexTomlConfig)
+        }["demo_a"]
+        with pytest.raises(TypeError):
+            ComplexTomlConfig.default_parser(field, path_bad)
diff --git a/tox.ini b/tox.ini
index 664ce9eebdf81f9942cfb2f10efb6f5b3b7f972e..5977c73a372f23e47931979e4a3d4eea01f0c184 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,12 +15,6 @@ commands=
         --cov --cov-report= \  # reset then accumulate coverage quietly
         --ignore=test/functional/ \
         test
-    # run functional tests (that build on units)
-    pytest {posargs} \
-        --cov --cov-append --cov-report=term \  # acc. and display coverage
-        test/functional/
-    # export the finalized coverage report to xml
-    coverage xml
     # verify code acceptance by pylint
     pylint declearn
     pylint --recursive=y test
@@ -29,6 +23,12 @@ commands=
     # verify code formatting
     black --check declearn
    black --check test
+    # run functional ~ integration tests (that build on unit ones)
+    pytest {posargs} \
+        --cov --cov-append --cov-report=term \  # acc. and display coverage
+        test/functional/
+    # export the finalized coverage report to xml
+    coverage xml

[pytest]
addopts = --full-trace
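To complement the tests above, here is a minimal sketch of the multi-file workflow enabled by the revised `TomlConfig.default_parser`: a config field may reference a secondary TOML file by path, which is then parsed recursively. The `MainConfig` and `SubConfig` classes, field names and file paths are illustrative assumptions, not part of the patch.

```python
import dataclasses

from declearn.utils import TomlConfig


@dataclasses.dataclass
class SubConfig(TomlConfig):
    """Illustrative nested configuration section."""

    value: str = "default"


@dataclasses.dataclass
class MainConfig(TomlConfig):
    """Illustrative top-level configuration with a nested field."""

    sub: SubConfig


# Write a secondary TOML file, then a main one referencing it by path.
with open("sub.toml", "w", encoding="utf-8") as file:
    file.write('value = "overridden"\n')
with open("main.toml", "w", encoding="utf-8") as file:
    file.write("sub = 'sub.toml'\n")

# Parsing the main file recursively parses the secondary one: the str
# input is detected to be a file path and routed back through the
# (now classmethod) default_parser.
config = MainConfig.from_toml("main.toml")
assert config.sub.value == "overridden"
```

A malformed secondary file raises a `TypeError` from the parser, which `from_toml` wraps into a `RuntimeError`, as exercised by `test_multifiles_config_fails` above.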