diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py index 583b2893773c81011b32d9dd6d4dea6b91bf0aef..a3e6dfee994f0a15ae9c289d049b73690aacfd1f 100644 --- a/declearn/model/api/_model.py +++ b/declearn/model/api/_model.py @@ -18,7 +18,7 @@ """Model abstraction API.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar import numpy as np from typing_extensions import Self # future: import from typing (py >=3.11) @@ -33,8 +33,12 @@ __all__ = [ ] +VectorT = TypeVar("VectorT", bound=Vector) +"""Type-annotation for the Vector subclass proper to a given Model.""" + + @create_types_registry -class Model(metaclass=ABCMeta): +class Model(Generic[VectorT], metaclass=ABCMeta): """Abstract class defining an API to manipulate a ML model. A 'Model' is an abstraction that defines a generic interface @@ -119,7 +123,7 @@ class Model(metaclass=ABCMeta): def get_weights( self, trainable: bool = False, - ) -> Vector: + ) -> VectorT: """Return the model's weights, optionally excluding frozen ones. Parameters @@ -140,7 +144,7 @@ class Model(metaclass=ABCMeta): @abstractmethod def set_weights( self, - weights: Vector, + weights: VectorT, trainable: bool = False, ) -> None: """Assign values to the model's weights. @@ -176,7 +180,7 @@ class Model(metaclass=ABCMeta): self, batch: Batch, max_norm: Optional[float] = None, - ) -> Vector: + ) -> VectorT: """Compute and return gradients computed over a given data batch. Compute the average gradients of the model's loss with respect @@ -204,7 +208,7 @@ class Model(metaclass=ABCMeta): @abstractmethod def apply_updates( self, - updates: Vector, + updates: VectorT, ) -> None: """Apply updates to the model's weights.""" diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py index 294addd2db35007fb9fec270cf09168a6eb0c65e..f407dc3b3c609509b40c9a4c3ee17c9b3af40ef0 100644 --- a/declearn/model/api/_vector.py +++ b/declearn/model/api/_vector.py @@ -19,7 +19,10 @@ import operator from abc import ABCMeta, abstractmethod -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union +from typing import ( + # fmt: off + Any, Callable, Dict, Generic, Optional, Set, Tuple, Type, TypeVar, Union +) from numpy.typing import ArrayLike from typing_extensions import Self # future: import from typing (Py>=3.11) @@ -37,10 +40,15 @@ __all__ = [ VECTOR_TYPES = {} # type: Dict[Type[Any], Type[Vector]] +"""Private constant holding registered Vector types.""" + + +T = TypeVar("T") +"""Type-annotation for the data structures proper to a given Vector class.""" @create_types_registry -class Vector(metaclass=ABCMeta): +class Vector(Generic[T], metaclass=ABCMeta): """Abstract class defining an API to manipulate (sets of) data arrays. A Vector is an abstraction used to wrap a collection of data @@ -62,27 +70,27 @@ class Vector(metaclass=ABCMeta): """ @property - def _op_add(self) -> Callable[[Any, Any], Any]: + def _op_add(self) -> Callable[[Any, Any], T]: """Framework-compatible addition operator.""" return operator.add @property - def _op_sub(self) -> Callable[[Any, Any], Any]: + def _op_sub(self) -> Callable[[Any, Any], T]: """Framework-compatible substraction operator.""" return operator.sub @property - def _op_mul(self) -> Callable[[Any, Any], Any]: + def _op_mul(self) -> Callable[[Any, Any], T]: """Framework-compatible multiplication operator.""" return operator.mul @property - def _op_div(self) -> Callable[[Any, Any], Any]: + def _op_div(self) -> Callable[[Any, Any], T]: """Framework-compatible true division operator.""" return operator.truediv @property - def _op_pow(self) -> Callable[[Any, Any], Any]: + def _op_pow(self) -> Callable[[Any, Any], T]: """Framework-compatible power operator.""" return operator.pow @@ -108,13 +116,13 @@ class Vector(metaclass=ABCMeta): def __init__( self, - coefs: Dict[str, Any], + coefs: Dict[str, T], ) -> None: """Instantiate the Vector to wrap a collection of data arrays. Parameters ---------- - coefs: dict[str, any] + coefs: dict[str, <T>] Dict grouping a named collection of data arrays. The supported types of that dict's values depends on the concrete `Vector` subclass being used. @@ -123,7 +131,7 @@ class Vector(metaclass=ABCMeta): @staticmethod def build( - coefs: Dict[str, Any], + coefs: Dict[str, T], ) -> "Vector": """Instantiate a Vector, inferring its exact subtype from coefs'. @@ -136,7 +144,7 @@ class Vector(metaclass=ABCMeta): Parameters ---------- - coefs: dict[str, any] + coefs: dict[str, <T>] Dict grouping a named collection of data arrays, that all belong to the same framework. @@ -189,7 +197,10 @@ class Vector(metaclass=ABCMeta): indexed by the coefficient's name. """ try: - return {key: coef.shape for key, coef in self.coefs.items()} + return { + key: coef.shape # type: ignore # exception caught + for key, coef in self.coefs.items() + } except AttributeError as exc: raise NotImplementedError( "Wrapped coefficients appear not to implement `.shape`.\n" @@ -210,7 +221,10 @@ class Vector(metaclass=ABCMeta): concrete framework of the Vector. """ try: - return {key: str(coef.dtype) for key, coef in self.coefs.items()} + return { + key: str(coef.dtype) # type: ignore # exception caught + for key, coef in self.coefs.items() + } except AttributeError as exc: raise NotImplementedError( "Wrapped coefficients appear not to implement `.dtype`.\n" @@ -261,7 +275,7 @@ class Vector(metaclass=ABCMeta): def apply_func( self, - func: Callable[..., Any], + func: Callable[..., T], *args: Any, **kwargs: Any, ) -> Self: @@ -290,14 +304,15 @@ class Vector(metaclass=ABCMeta): def _apply_operation( self, other: Any, - func: Callable[[Any, Any], Any], + func: Callable[[Any, Any], T], ) -> Self: """Apply an operation to combine this vector with another. Parameters ---------- - other: Vector - Vector with the same names, shapes and dtypes as this one. + other: + Vector with the same names, shapes and dtypes as this one; + or scalar object on which to operate (e.g. a float value). func: function(<T>, <T>) -> <T> Function to be applied to combine the data arrays stored in this vector and the `other` one. diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 49da62cfb0e0e42ad69976c8a225c990aff805f9..9fba66ebc6b533867e8f86ef465d1f98a8246117 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -58,29 +58,35 @@ class NumpyVector(Vector): """ @property - def _op_add(self) -> Callable[[Any, Any], Any]: + def _op_add(self) -> Callable[[Any, Any], np.ndarray]: return np.add @property - def _op_sub(self) -> Callable[[Any, Any], Any]: + def _op_sub(self) -> Callable[[Any, Any], np.ndarray]: return np.subtract @property - def _op_mul(self) -> Callable[[Any, Any], Any]: + def _op_mul(self) -> Callable[[Any, Any], np.ndarray]: return np.multiply @property - def _op_div(self) -> Callable[[Any, Any], Any]: + def _op_div(self) -> Callable[[Any, Any], np.ndarray]: return np.divide @property - def _op_pow(self) -> Callable[[Any, Any], Any]: + def _op_pow(self) -> Callable[[Any, Any], np.ndarray]: return np.power - def __init__(self, coefs: Dict[str, np.ndarray]) -> None: + def __init__( + self, + coefs: Dict[str, np.ndarray], + ) -> None: super().__init__(coefs) - def __eq__(self, other: Any) -> bool: + def __eq__( + self, + other: Any, + ) -> bool: valid = isinstance(other, NumpyVector) if valid: valid = self.coefs.keys() == other.coefs.keys() diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 18ad7877d2e9f27fae28e0ffaa8807d8cbfcb4e9..6ab4c88087b4cc775f17aa9dc3b6ec9f9d19b78f 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -282,7 +282,7 @@ class SklearnSGDModel(Model): } return NumpyVector(weights) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: NumpyVector, trainable: bool = False, @@ -356,7 +356,7 @@ class SklearnSGDModel(Model): # Compute gradients based on weights' update. return w_srt - w_end - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: NumpyVector, ) -> None: diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py index 581c285ca2d0ec80524b392821942d36613e5606..59260eaededbf9182178d82ec282c68964867e27 100644 --- a/declearn/model/tensorflow/_model.py +++ b/declearn/model/tensorflow/_model.py @@ -183,7 +183,7 @@ class TensorflowModel(Model): ) return TensorflowVector({var.name: var.value() for var in variables}) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: TensorflowVector, trainable: bool = False, @@ -319,7 +319,7 @@ class TensorflowModel(Model): outp.append(tf.reduce_mean(grad * s_wght, axis=0)) return outp - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: TensorflowVector, ) -> None: diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 7491c3f53b444f8216143a39ce0d15249c810ebc..f8a5dd6f1396d6d6909edbf070ee616e6cc3a6ec 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -168,7 +168,7 @@ class TorchModel(Model): # Note: calling `tensor.clone()` to return a copy rather than a view. return TorchVector({k: t.detach().clone() for k, t in weights.items()}) - def set_weights( # type: ignore # Vector subtype specification + def set_weights( self, weights: TorchVector, trainable: bool = False, @@ -378,7 +378,7 @@ class TorchModel(Model): return grads_fn return functorch.compile.aot_function(grads_fn, functorch.compile.nop) - def apply_updates( # type: ignore # Vector subtype specification + def apply_updates( self, updates: TorchVector, ) -> None: diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py index 5e14779cfa5ce52a16fb4143b9a61064218b5aa2..7480d8f8b077d0cf79ca369ce12c65e516182675 100644 --- a/declearn/optimizer/_base.py +++ b/declearn/optimizer/_base.py @@ -17,7 +17,10 @@ """Base class to define gradient-descent-based optimizers.""" -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + # fmt: off + Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union +) from typing_extensions import Self # future: import from typing (py >=3.11) @@ -31,6 +34,9 @@ __all__ = [ ] +T = TypeVar("T") + + class Optimizer: """Base class to define gradient-descent-based optimizers. @@ -255,9 +261,9 @@ class Optimizer: def compute_updates_from_gradients( self, - model: Model, - gradients: Vector, - ) -> Vector: + model: Model[Vector[T]], + gradients: Vector[T], + ) -> Vector[T]: """Compute and return model updates based on pre-computed gradients. Parameters @@ -393,8 +399,8 @@ class Optimizer: def apply_gradients( self, - model: Model, - gradients: Vector, + model: Model[Vector[T]], + gradients: Vector[T], ) -> None: """Compute and apply model updates based on pre-computed gradients. diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 197346e0bc425f94e053cf336d3eb05bfd5e5468..55d3f37cfa6c049f2628d8581e2803590a4d6b2b 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -18,7 +18,7 @@ """Base API for plug-in gradients-alteration algorithms.""" from abc import ABCMeta, abstractmethod -from typing import Any, ClassVar, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) @@ -34,6 +34,9 @@ __all__ = [ ] +T = TypeVar("T") + + @create_types_registry class OptiModule(metaclass=ABCMeta): """Abstract class defining an API to implement gradients adaptation tools. @@ -117,8 +120,8 @@ class OptiModule(metaclass=ABCMeta): @abstractmethod def run( self, - gradients: Vector, - ) -> Vector: + gradients: Vector[T], + ) -> Vector[T]: """Apply the module's algorithm to input gradients. Please refer to the module's main docstring for details diff --git a/declearn/optimizer/regularizers/_api.py b/declearn/optimizer/regularizers/_api.py index 1a9dfa5bc2151971a101c48e11ce7c480d4ec7ca..187f5eb5ef2875b11f17da7dfda697fe4c50c38c 100644 --- a/declearn/optimizer/regularizers/_api.py +++ b/declearn/optimizer/regularizers/_api.py @@ -18,7 +18,7 @@ """Base API for loss regularization optimizer plug-ins.""" from abc import ABCMeta, abstractmethod -from typing import Any, ClassVar, Dict +from typing import Any, ClassVar, Dict, TypeVar from typing_extensions import Self # future: import from typing (py >=3.11) @@ -34,6 +34,9 @@ __all__ = [ ] +T = TypeVar("T") + + @create_types_registry class Regularizer(metaclass=ABCMeta): """Abstract class defining an API to implement loss-regularizers. @@ -115,9 +118,9 @@ class Regularizer(metaclass=ABCMeta): @abstractmethod def run( self, - gradients: Vector, - weights: Vector, - ) -> Vector: + gradients: Vector[T], + weights: Vector[T], + ) -> Vector[T]: """Compute and add the regularization term's derivative to gradients. Parameters