Verified commit 7f3327b1, authored by ANDREY Paul

Enhance typing of 'Model' and 'Vector' using 'typing.Generic'.

- Make `Vector` a `Generic[T]` to indicate that stored data arrays
  should have coherent types.
- Make `Model` a `Generic[VectorT]` to indicate that a given model
  class is associated with a unique vector class.
- In `Optimizer` and plug-in ABCs, specify that inputs and outputs
  should be coherent, using `Vector[T]` and `Model[Vector[T]]`
  type hints.
- In concrete plug-in classes, leave the existing code as-is so as
  not to make it harder to read.
parent a748ae71
Merge request !44: Minor gardening around the package
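In short, the commit applies the following pattern (a condensed sketch of the diff below, not the actual source; class bodies, docstrings and registration logic are omitted):

from abc import ABCMeta, abstractmethod
from typing import Dict, Generic, TypeVar

T = TypeVar("T")  # data-array type proper to a given Vector subclass


class Vector(Generic[T], metaclass=ABCMeta):
    """Wraps a named collection of data arrays sharing one type T."""

    def __init__(self, coefs: Dict[str, T]) -> None:
        self.coefs = coefs


VectorT = TypeVar("VectorT", bound=Vector)  # Vector subclass proper to a Model


class Model(Generic[VectorT], metaclass=ABCMeta):
    """Exchanges weights with a single, coherent Vector subclass."""

    @abstractmethod
    def get_weights(self, trainable: bool = False) -> VectorT:
        """Return the model's weights as its dedicated Vector subclass."""

    @abstractmethod
    def set_weights(self, weights: VectorT, trainable: bool = False) -> None:
        """Assign values to the model's weights."""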
@@ -18,7 +18,7 @@
 """Model abstraction API."""

 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Optional, Set, Tuple
+from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar

 import numpy as np
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -33,8 +33,12 @@ __all__ = [
 ]

+VectorT = TypeVar("VectorT", bound=Vector)
+"""Type-annotation for the Vector subclass proper to a given Model."""
+

 @create_types_registry
-class Model(metaclass=ABCMeta):
+class Model(Generic[VectorT], metaclass=ABCMeta):
     """Abstract class defining an API to manipulate a ML model.

     A 'Model' is an abstraction that defines a generic interface
@@ -119,7 +123,7 @@ class Model(metaclass=ABCMeta):
     def get_weights(
         self,
         trainable: bool = False,
-    ) -> Vector:
+    ) -> VectorT:
         """Return the model's weights, optionally excluding frozen ones.

         Parameters
@@ -140,7 +144,7 @@
     @abstractmethod
     def set_weights(
         self,
-        weights: Vector,
+        weights: VectorT,
         trainable: bool = False,
     ) -> None:
         """Assign values to the model's weights.
@@ -176,7 +180,7 @@
         self,
         batch: Batch,
         max_norm: Optional[float] = None,
-    ) -> Vector:
+    ) -> VectorT:
         """Compute and return gradients computed over a given data batch.

         Compute the average gradients of the model's loss with respect
@@ -204,7 +208,7 @@
     @abstractmethod
     def apply_updates(
         self,
-        updates: Vector,
+        updates: VectorT,
     ) -> None:
         """Apply updates to the model's weights."""
@@ -19,7 +19,10 @@
 import operator
 from abc import ABCMeta, abstractmethod
-from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, Generic, Optional, Set, Tuple, Type, TypeVar, Union
+)

 from numpy.typing import ArrayLike
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
@@ -37,10 +40,15 @@ __all__ = [
 ]

 VECTOR_TYPES = {}  # type: Dict[Type[Any], Type[Vector]]
 """Private constant holding registered Vector types."""

+T = TypeVar("T")
+"""Type-annotation for the data structures proper to a given Vector class."""
+

 @create_types_registry
-class Vector(metaclass=ABCMeta):
+class Vector(Generic[T], metaclass=ABCMeta):
     """Abstract class defining an API to manipulate (sets of) data arrays.

     A Vector is an abstraction used to wrap a collection of data
@@ -62,27 +70,27 @@ class Vector(metaclass=ABCMeta):
     """

     @property
-    def _op_add(self) -> Callable[[Any, Any], Any]:
+    def _op_add(self) -> Callable[[Any, Any], T]:
         """Framework-compatible addition operator."""
         return operator.add

     @property
-    def _op_sub(self) -> Callable[[Any, Any], Any]:
+    def _op_sub(self) -> Callable[[Any, Any], T]:
         """Framework-compatible subtraction operator."""
         return operator.sub

     @property
-    def _op_mul(self) -> Callable[[Any, Any], Any]:
+    def _op_mul(self) -> Callable[[Any, Any], T]:
         """Framework-compatible multiplication operator."""
         return operator.mul

     @property
-    def _op_div(self) -> Callable[[Any, Any], Any]:
+    def _op_div(self) -> Callable[[Any, Any], T]:
         """Framework-compatible true division operator."""
         return operator.truediv

     @property
-    def _op_pow(self) -> Callable[[Any, Any], Any]:
+    def _op_pow(self) -> Callable[[Any, Any], T]:
         """Framework-compatible power operator."""
         return operator.pow
@@ -108,13 +116,13 @@
     def __init__(
         self,
-        coefs: Dict[str, Any],
+        coefs: Dict[str, T],
     ) -> None:
         """Instantiate the Vector to wrap a collection of data arrays.

         Parameters
         ----------
-        coefs: dict[str, any]
+        coefs: dict[str, <T>]
             Dict grouping a named collection of data arrays.
             The supported types of that dict's values depend
             on the concrete `Vector` subclass being used.
@@ -123,7 +131,7 @@
     @staticmethod
     def build(
-        coefs: Dict[str, Any],
+        coefs: Dict[str, T],
     ) -> "Vector":
         """Instantiate a Vector, inferring its exact subtype from 'coefs'.
@@ -136,7 +144,7 @@
         Parameters
         ----------
-        coefs: dict[str, any]
+        coefs: dict[str, <T>]
             Dict grouping a named collection of data arrays, that
             all belong to the same framework.
@@ -189,7 +197,10 @@
         indexed by the coefficient's name.
         """
         try:
-            return {key: coef.shape for key, coef in self.coefs.items()}
+            return {
+                key: coef.shape  # type: ignore  # exception caught
+                for key, coef in self.coefs.items()
+            }
         except AttributeError as exc:
             raise NotImplementedError(
                 "Wrapped coefficients appear not to implement `.shape`.\n"
@@ -210,7 +221,10 @@
         concrete framework of the Vector.
         """
         try:
-            return {key: str(coef.dtype) for key, coef in self.coefs.items()}
+            return {
+                key: str(coef.dtype)  # type: ignore  # exception caught
+                for key, coef in self.coefs.items()
+            }
         except AttributeError as exc:
             raise NotImplementedError(
                 "Wrapped coefficients appear not to implement `.dtype`.\n"
@@ -261,7 +275,7 @@
     def apply_func(
         self,
-        func: Callable[..., Any],
+        func: Callable[..., T],
         *args: Any,
         **kwargs: Any,
     ) -> Self:
@@ -290,14 +304,15 @@
     def _apply_operation(
         self,
         other: Any,
-        func: Callable[[Any, Any], Any],
+        func: Callable[[Any, Any], T],
     ) -> Self:
         """Apply an operation to combine this vector with another.

         Parameters
         ----------
-        other: Vector
-            Vector with the same names, shapes and dtypes as this one.
+        other:
+            Vector with the same names, shapes and dtypes as this one;
+            or scalar object on which to operate (e.g. a float value).
         func: function(<T>, <T>) -> <T>
             Function to be applied to combine the data arrays stored
             in this vector and the `other` one.
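One practical effect of the parameterization: arithmetic between vectors is now declared to stay within a single framework. A hypothetical helper illustrating this (assuming the arithmetic dunders that `Vector` builds on top of `_apply_operation`):

from typing import TypeVar

T = TypeVar("T")  # as in the diff above

def average(left: "Vector[T]", right: "Vector[T]") -> "Vector[T]":
    # Both operands must share the same underlying array type T: mixing,
    # e.g., a numpy-backed and a tensorflow-backed Vector is now flagged
    # by type checkers instead of only failing at runtime.
    return (left + right) * 0.5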
@@ -58,29 +58,35 @@ class NumpyVector(Vector):
     """

     @property
-    def _op_add(self) -> Callable[[Any, Any], Any]:
+    def _op_add(self) -> Callable[[Any, Any], np.ndarray]:
         return np.add

     @property
-    def _op_sub(self) -> Callable[[Any, Any], Any]:
+    def _op_sub(self) -> Callable[[Any, Any], np.ndarray]:
         return np.subtract

     @property
-    def _op_mul(self) -> Callable[[Any, Any], Any]:
+    def _op_mul(self) -> Callable[[Any, Any], np.ndarray]:
         return np.multiply

     @property
-    def _op_div(self) -> Callable[[Any, Any], Any]:
+    def _op_div(self) -> Callable[[Any, Any], np.ndarray]:
         return np.divide

     @property
-    def _op_pow(self) -> Callable[[Any, Any], Any]:
+    def _op_pow(self) -> Callable[[Any, Any], np.ndarray]:
         return np.power

-    def __init__(self, coefs: Dict[str, np.ndarray]) -> None:
+    def __init__(
+        self,
+        coefs: Dict[str, np.ndarray],
+    ) -> None:
         super().__init__(coefs)

-    def __eq__(self, other: Any) -> bool:
+    def __eq__(
+        self,
+        other: Any,
+    ) -> bool:
         valid = isinstance(other, NumpyVector)
         if valid:
             valid = self.coefs.keys() == other.coefs.keys()
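With `NumpyVector` declaring `np.ndarray` returns for its operators, a checker can infer concrete result types from vector arithmetic. A small usage sketch (illustrative values, assuming `+` dispatches to `_op_add` via `_apply_operation` as in the base class):

import numpy as np

v1 = NumpyVector({"kernel": np.ones((2, 2)), "bias": np.zeros(2)})
v2 = NumpyVector({"kernel": np.ones((2, 2)), "bias": np.ones(2)})

v3 = v1 + v2  # element-wise np.add on each named coefficient
assert isinstance(v3, NumpyVector)
assert (v3.coefs["kernel"] == 2.0).all()
assert v1 == NumpyVector({"kernel": np.ones((2, 2)), "bias": np.zeros(2)})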
@@ -282,7 +282,7 @@ class SklearnSGDModel(Model):
         }
         return NumpyVector(weights)

-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: NumpyVector,
         trainable: bool = False,
@@ -356,7 +356,7 @@ class SklearnSGDModel(Model):
         # Compute gradients based on weights' update.
         return w_srt - w_end

-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: NumpyVector,
     ) -> None:
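Why the `# type: ignore` markers can go: they were needed because overriding `set_weights` / `apply_updates` with `NumpyVector`-typed arguments contradicted the base `Vector` annotations. With `Model` now generic, a subclass inheriting plain `Model` is treated as `Model[Any]` under standard mypy semantics, so the narrower overrides pass. A sketch of the two options (the stricter variant is hypothetical, not part of this commit):

# Option kept by the commit: inherit the unparameterized base, which type
# checkers treat as Model[Any]; narrower overrides are then accepted.
class SklearnSGDModel(Model):
    def set_weights(
        self,
        weights: NumpyVector,  # no '# type: ignore' needed anymore
        trainable: bool = False,
    ) -> None:
        ...

# Stricter alternative (hypothetical): pin the vector type in the class
# declaration so checkers also enforce it at call sites.
class StrictSklearnSGDModel(Model[NumpyVector]):
    ...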
@@ -183,7 +183,7 @@ class TensorflowModel(Model):
         )
         return TensorflowVector({var.name: var.value() for var in variables})

-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: TensorflowVector,
         trainable: bool = False,
@@ -319,7 +319,7 @@ class TensorflowModel(Model):
             outp.append(tf.reduce_mean(grad * s_wght, axis=0))
         return outp

-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: TensorflowVector,
     ) -> None:
@@ -168,7 +168,7 @@ class TorchModel(Model):
         # Note: calling `tensor.clone()` to return a copy rather than a view.
         return TorchVector({k: t.detach().clone() for k, t in weights.items()})

-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: TorchVector,
         trainable: bool = False,
@@ -378,7 +378,7 @@ class TorchModel(Model):
             return grads_fn
         return functorch.compile.aot_function(grads_fn, functorch.compile.nop)

-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: TorchVector,
     ) -> None:
@@ -17,7 +17,10 @@
 """Base class to define gradient-descent-based optimizers."""

-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    # fmt: off
+    Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+)

 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -31,6 +34,9 @@ __all__ = [
 ]

+T = TypeVar("T")
+

 class Optimizer:
     """Base class to define gradient-descent-based optimizers.
@@ -255,9 +261,9 @@
     def compute_updates_from_gradients(
         self,
-        model: Model,
-        gradients: Vector,
-    ) -> Vector:
+        model: Model[Vector[T]],
+        gradients: Vector[T],
+    ) -> Vector[T]:
         """Compute and return model updates based on pre-computed gradients.

         Parameters
@@ -393,8 +399,8 @@
     def apply_gradients(
         self,
-        model: Model,
-        gradients: Vector,
+        model: Model[Vector[T]],
+        gradients: Vector[T],
     ) -> None:
         """Compute and apply model updates based on pre-computed gradients.
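The shared `T` ties the model and gradient arguments together: a call mixing frameworks no longer type-checks. A hypothetical training-step helper (the gradient-computing method's name is assumed here; only its `Batch -> VectorT` signature appears in the diff above):

def training_step(
    model: Model[Vector[T]],
    optim: Optimizer,
    batch: Batch,
) -> None:
    # Gradients come out typed as Vector[T]...
    gradients = model.compute_batch_gradients(batch)
    # ...and must go back in with the very same T.
    optim.apply_gradients(model, gradients)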
@@ -18,7 +18,7 @@
 """Base API for plug-in gradients-alteration algorithms."""

 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict, Optional
+from typing import Any, ClassVar, Dict, Optional, TypeVar

 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -34,6 +34,9 @@ __all__ = [
 ]

+T = TypeVar("T")
+

 @create_types_registry
 class OptiModule(metaclass=ABCMeta):
     """Abstract class defining an API to implement gradients adaptation tools.
@@ -117,8 +120,8 @@
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+    ) -> Vector[T]:
         """Apply the module's algorithm to input gradients.

         Please refer to the module's main docstring for details
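For plug-ins, the `Vector[T] -> Vector[T]` signature encodes that a module returns gradients of the same framework it received. A toy module (hypothetical, relying on the scalar arithmetic that the `Vector` API provides):

from typing import TypeVar

T = TypeVar("T")

class GradientScaler(OptiModule):
    """Toy plug-in that rescales input gradients by a constant factor."""

    name = "gradient-scaler-demo"  # registration name, hypothetical

    def __init__(self, factor: float = 0.1) -> None:
        self.factor = factor

    def run(self, gradients: "Vector[T]") -> "Vector[T]":
        # Scalar multiplication goes through the Vector API and thus
        # preserves the input's concrete subtype (and its T).
        return gradients * self.factor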
@@ -18,7 +18,7 @@
 """Base API for loss regularization optimizer plug-ins."""

 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict
+from typing import Any, ClassVar, Dict, TypeVar

 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -34,6 +34,9 @@ __all__ = [
 ]

+T = TypeVar("T")
+

 @create_types_registry
 class Regularizer(metaclass=ABCMeta):
     """Abstract class defining an API to implement loss-regularizers.
@@ -115,9 +118,9 @@
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-        weights: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+        weights: Vector[T],
+    ) -> Vector[T]:
         """Compute and add the regularization term's derivative to gradients.

         Parameters