From 0f4eced5d78b8da09c0db2db8e8cf4ced1a68f0c Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Tue, 7 Mar 2023 12:15:10 +0100
Subject: [PATCH] Applying corrections after MR review:

* Re-introducing deprecated `DataInfoField` subclasses with a warning
* Other minor changes to the `sklearn` and `tf` models and `DataTypeField`

---
 declearn/data_info/__init__.py      |   6 +-
 declearn/data_info/_fields.py       | 143 ++++++++++++++++++++--------
 declearn/dataset/_base.py           |   4 +-
 declearn/dataset/_inmemory.py       |   2 +-
 declearn/model/sklearn/_sgd.py      |  11 ++-
 declearn/model/tensorflow/_model.py |  25 ++---
 test/functional/test_regression.py  |   2 +-
 7 files changed, 133 insertions(+), 60 deletions(-)

diff --git a/declearn/data_info/__init__.py b/declearn/data_info/__init__.py
index 7dc304e9..4b4ede5f 100644
--- a/declearn/data_info/__init__.py
+++ b/declearn/data_info/__init__.py
@@ -57,7 +57,9 @@ from ._base import (
 )
 from ._fields import (
     ClassesField,
-    SingleInputShapeField,
+    DataTypeField,
+    InputShapeField,
+    NbFeaturesField,
     NbSamplesField,
-    DataTypeField
+    SingleInputShapeField,
 )
diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py
index 2f27461b..dd0b0bcb 100644
--- a/declearn/data_info/_fields.py
+++ b/declearn/data_info/_fields.py
@@ -17,7 +17,8 @@

 """DataInfoField subclasses specifying common 'data_info' metadata fields."""

-from typing import Any, ClassVar, Optional, Set, Tuple, Type
+from typing import Any, ClassVar, List, Optional, Set, Tuple, Type
+from warnings import warn

 import numpy as np

@@ -27,6 +28,8 @@
 __all__ = [
     "ClassesField",
     "SingleInputShapeField",
     "NbSamplesField",
+    "NbFeaturesField",
+    "InputShapeField",
 ]

@@ -91,39 +94,7 @@ class SingleInputShapeField(DataInfoField):
                 f"Cannot combine '{cls.field}': inputs don't have the same"
                 "shape"
             )
-        return unique_shapes[0] # type: ignore
-
-
-# @register_data_info_field
-# class NbFeaturesField(DataInfoField):
-#     """Specifications for 'n_features' data_info field."""
-#
-#     field: ClassVar[str] = "n_features"
-#     types: ClassVar[Tuple[Type, ...]] = (int,)
-#     doc: ClassVar[str] = "Number of input features, checked to be equal."
-#
-#     @classmethod
-#     def is_valid(
-#         cls,
-#         value: Any,
-#     ) -> bool:
-#         return isinstance(value, int) and (value > 0)
-#
-#     @classmethod
-#     def combine(
-#         cls,
-#         *values: Any,
-#     ) -> int:
-#         unique = list(set(values))
-#         if len(unique) != 1:
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': non-unique inputs."
-#             )
-#         if not cls.is_valid(unique[0]):
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': invalid unique value."
-#             )
-#         return unique[0]
+        return unique_shapes[0]  # type: ignore


 @register_data_info_field
@@ -165,16 +136,108 @@ class DataTypeField(DataInfoField):
     ) -> bool:
         out = isinstance(value, str)
         if out:
-            # CHECK
             try:
                 np.dtype(value)
-            except TypeError as exp:
-                raise TypeError(
-                    "The received string could not be parsed"
-                    "to a valid array dtype"
-                ) from exp
+            except TypeError:
+                out = False
         return out

+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> str:
+        super().combine(*values)
+        unique = list(set(values))
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': non-unique inputs."
+            )
+        if not cls.is_valid(unique[0]):
+            raise ValueError(
+                f"Cannot combine '{cls.field}': invalid unique value."
+            )
+        return unique[0]
+
+
+@register_data_info_field
+class InputShapeField(DataInfoField):
+    """Deprecated specifications for 'input_shape' data_info field."""
+
+    field: ClassVar[str] = "input_shape"
+    types: ClassVar[Tuple[Type, ...]] = (tuple, list)
+    doc: ClassVar[str] = "Input features' batched shape, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} is deprecated and will be removed.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return (
+            isinstance(value, cls.types)
+            and (len(value) >= 2)
+            and all(isinstance(val, int) or (val is None) for val in value)
+        )
+
+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> List[Optional[int]]:
+        # Type-check each and every input shape.
+        super().combine(*values)
+        # Check that all shapes are of same length.
+        unique = list({len(shp) for shp in values})
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': inputs have various lengths."
+            )
+        # Fill in the unified shape: each dim must be None or a unique value.
+        # Note: batching dimension is set to None by default (no check).
+        shape = [None] * unique[0]  # type: List[Optional[int]]
+        for i in range(1, unique[0]):
+            val = [shp[i] for shp in values if shp[i] is not None]
+            if not val:  # all None
+                shape[i] = None
+            elif len(set(val)) > 1:
+                raise ValueError(
+                    f"Cannot combine '{cls.field}': provided shapes differ."
+                )
+            else:
+                shape[i] = val[0]
+        # Return the combined shape.
+        return shape
+
+
+@register_data_info_field
+class NbFeaturesField(DataInfoField):
+    """Deprecated specifications for 'n_features' data_info field."""
+
+    field: ClassVar[str] = "n_features"
+    types: ClassVar[Tuple[Type, ...]] = (int,)
+    doc: ClassVar[str] = "Number of input features, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} is deprecated and will be removed.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return isinstance(value, int) and (value > 0)
+
     @classmethod
     def combine(
         cls,
diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py
index 4e201fca..a7e9c167 100644
--- a/declearn/dataset/_base.py
+++ b/declearn/dataset/_base.py
@@ -19,7 +19,7 @@

 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Any, ClassVar, Iterator, Optional, Set, Tuple, Union, List
+from typing import Any, ClassVar, Iterator, List, Optional, Set, Tuple, Union

 from typing_extensions import Self  # future: import from typing (py >=3.11)

@@ -38,7 +38,7 @@ class DataSpecs:
     """Dataclass to wrap a dataset's metadata."""

     n_samples: int
-    single_input_shape: Union[Tuple[int],List[int]]
+    single_input_shape: Union[Tuple[int, ...], List[int]]
     classes: Optional[Set[Any]] = None
     data_type: Optional[str] = None

diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py
index 51d33d2a..6efd2714 100644
--- a/declearn/dataset/_inmemory.py
+++ b/declearn/dataset/_inmemory.py
@@ -440,7 +440,7 @@ class InMemoryDataset(Dataset):
         """Return a DataSpecs object describing this dataset."""
         return DataSpecs(
             n_samples=self.feats.shape[0],
-            single_input_shape=self.feats.shape[1:], # type: ignore
+            single_input_shape=self.feats.shape[1:],  # type: ignore
             classes=self.classes,
             data_type=self.data_type,
         )
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index d4d48a4c..cc5ee6e3 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -136,9 +136,13 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros(
-                (dim, *data_info["single_input_shape"])
+            if len(data_info["single_input_shape"]) != 1:
+                raise ValueError(
+                    "SklearnSGDModel currently only supports "
+                    "flat, one-dimensional features"
             )
+            feat = data_info["single_input_shape"][0]
+            self._model.coef_ = np.zeros((dim, feat))
             self._model.intercept_ = np.zeros((dim,))
         # SGDRegressor case.
         else:
@@ -380,7 +384,7 @@
     ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
         """Return a function to compute point-wise loss for a given batch."""
         # fmt: off
-        # Gather or instantiate a loss function from the wrapped model's specs.
+        # Gather / instantiate a loss function from the wrapped model's specs.
         if hasattr(self._model, "loss_function_"):
             loss_smp = self._model.loss_function_.py_loss
         else:
@@ -399,6 +403,7 @@
         else:
             loss_fn = loss_1d
         return loss_fn
+        # fmt: on

     def update_device_policy(
         self,
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index 452608cb..10863a24 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -29,9 +29,11 @@ from declearn.data_info import aggregate_data_info
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.model.api import Model
 from declearn.model.tensorflow._vector import TensorflowVector
-from declearn.model.tensorflow.utils import (build_keras_loss,
-                                             move_layer_to_device,
-                                             select_device)
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.typing import Batch
 from declearn.utils import DevicePolicy, get_device_policy, register_type

@@ -60,7 +62,7 @@ class TensorflowModel(Model):
     * Note that if the global device-placement policy is updated, this will
       only be propagated to existing instances by manually calling their
       `update_device_policy` method.
-    * You may consult the device policy currently enforced by a TensorflowModel
+    * You may consult the device policy enforced by a TensorflowModel
       instance by accessing its `device_policy` property.
     """

@@ -129,20 +131,20 @@ class TensorflowModel(Model):
     def required_data_info(
         self,
     ) -> Set[str]:
-        return set() if self._model.built else {"n_samples",
-                                                "single_input_shape"}
+        return (
+            set() if self._model.built else {"n_samples", "single_input_shape"}
+        )

     def initialize(
         self,
         data_info: Dict[str, Any],
     ) -> None:
         if not self._model.built:
-            data_info = aggregate_data_info([data_info],
-                                            self.required_data_info)
+            data_info = aggregate_data_info(
+                [data_info], self.required_data_info
+            )
             with tf.device(self._device):
-                self._model.build(
-                    (data_info['n_samples'], *data_info['single_input_shape'])
-                )
+                self._model.build((None, *data_info["single_input_shape"]))

     def get_config(
         self,
@@ -256,6 +258,7 @@ class TensorflowModel(Model):
                 return tf.convert_to_tensor(data)
             # Apply it to the batched elements.
             return tf.nest.map_structure(convert, batch)
+        # fmt: on

     @tf.function  # optimize tensorflow runtime
     def _compute_batch_gradients(
diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index 70a52924..d08b5f16 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -285,7 +285,7 @@ def test_declearn_baseline(
     d_train = InMemoryDataset(train[0], train[1])
     # Set up a declearn model and a vanilla SGD optimizer.
     model = get_model("numpy")
-    model.initialize({"n_features": d_train.data.shape[1]})
+    model.initialize({"single_input_shape": (d_train.data.shape[1],)})
     opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
     # Iteratively train the model, evaluating it after each epoch.
     for _ in range(rounds):
--
GitLab