From 0f4eced5d78b8da09c0db2db8e8cf4ced1a68f0c Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Tue, 7 Mar 2023 12:15:10 +0100
Subject: [PATCH] Applying corrections after MR review:

* Re-introducing deprecated `DataInfoField` subclasses with a warning
* Other minor changes to the `sklearn` and `tf` models and `DataTypeField`

---
 declearn/data_info/__init__.py      |   6 +-
 declearn/data_info/_fields.py       | 143 ++++++++++++++++++++--------
 declearn/dataset/_base.py           |   4 +-
 declearn/dataset/_inmemory.py       |   2 +-
 declearn/model/sklearn/_sgd.py      |  11 ++-
 declearn/model/tensorflow/_model.py |  25 ++---
 test/functional/test_regression.py  |   2 +-
 7 files changed, 133 insertions(+), 60 deletions(-)

diff --git a/declearn/data_info/__init__.py b/declearn/data_info/__init__.py
index 7dc304e9..4b4ede5f 100644
--- a/declearn/data_info/__init__.py
+++ b/declearn/data_info/__init__.py
@@ -57,7 +57,9 @@ from ._base import (
 )
 from ._fields import (
     ClassesField,
-    SingleInputShapeField,
+    DataTypeField,
+    InputShapeField,
+    NbFeaturesField,
     NbSamplesField,
-    DataTypeField
+    SingleInputShapeField,
 )
diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py
index 2f27461b..dd0b0bcb 100644
--- a/declearn/data_info/_fields.py
+++ b/declearn/data_info/_fields.py
@@ -17,7 +17,8 @@

 """DataInfoField subclasses specifying common 'data_info' metadata fields."""

-from typing import Any, ClassVar, Optional, Set, Tuple, Type
+from typing import Any, ClassVar, List, Optional, Set, Tuple, Type
+from warnings import warn

 import numpy as np

@@ -27,6 +28,8 @@
 __all__ = [
     "ClassesField",
     "SingleInputShapeField",
     "NbSamplesField",
+    "NbFeaturesField",
+    "InputShapeField",
 ]

@@ -91,39 +94,7 @@ class SingleInputShapeField(DataInfoField):
                 f"Cannot combine '{cls.field}': inputs don't have the same"
                 "shape"
             )
-        return unique_shapes[0] # type: ignore
-
-
-# @register_data_info_field
-# class NbFeaturesField(DataInfoField):
-#     """Specifications for 'n_features' data_info field."""
-#
-#     field: ClassVar[str] = "n_features"
-#     types: ClassVar[Tuple[Type, ...]] = (int,)
-#     doc: ClassVar[str] = "Number of input features, checked to be equal."
-#
-#     @classmethod
-#     def is_valid(
-#         cls,
-#         value: Any,
-#     ) -> bool:
-#         return isinstance(value, int) and (value > 0)
-#
-#     @classmethod
-#     def combine(
-#         cls,
-#         *values: Any,
-#     ) -> int:
-#         unique = list(set(values))
-#         if len(unique) != 1:
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': non-unique inputs."
-#             )
-#         if not cls.is_valid(unique[0]):
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': invalid unique value."
-#             )
-#         return unique[0]
+        return unique_shapes[0]  # type: ignore


 @register_data_info_field
@@ -165,16 +136,108 @@ class DataTypeField(DataInfoField):
     ) -> bool:
         out = isinstance(value, str)
         if out:
-            # CHECK
             try:
                 np.dtype(value)
-            except TypeError as exp:
-                raise TypeError(
-                    "The received string could not be parsed"
-                    "to a valid array dtype"
-                ) from exp
+            except TypeError:
+                out = False
         return out

+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> str:
+        super().combine(*values)
+        unique = list(set(values))
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': non-unique inputs."
+            )
+        if not cls.is_valid(unique[0]):
+            raise ValueError(
+                f"Cannot combine '{cls.field}': invalid unique value."
+            )
+        return unique[0]
+
+
+@register_data_info_field
+class InputShapeField(DataInfoField):
+    """Deprecated specifications for 'input_shape' data_info field."""
+
+    field: ClassVar[str] = "input_shape"
+    types: ClassVar[Tuple[Type, ...]] = (tuple, list)
+    doc: ClassVar[str] = "Input features' batched shape, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} is deprecated and will be removed.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return (
+            isinstance(value, cls.types)
+            and (len(value) >= 2)
+            and all(isinstance(val, int) or (val is None) for val in value)
+        )
+
+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> List[Optional[int]]:
+        # Type-check each and every input shape.
+        super().combine(*values)
+        # Check that all shapes are of same length.
+        unique = list({len(shp) for shp in values})
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': inputs have various lengths."
+            )
+        # Fill in the unified shape: each dim must be None or a unique value.
+        # Note: batching dimension is set to None by default (no check).
+        shape = [None] * unique[0]  # type: List[Optional[int]]
+        for i in range(1, unique[0]):
+            val = [shp[i] for shp in values if shp[i] is not None]
+            if not val:  # all None
+                shape[i] = None
+            elif len(set(val)) > 1:
+                raise ValueError(
+                    f"Cannot combine '{cls.field}': provided shapes differ."
+                )
+            else:
+                shape[i] = val[0]
+        # Return the combined shape.
+        return shape
+
+
+@register_data_info_field
+class NbFeaturesField(DataInfoField):
+    """Deprecated specifications for 'n_features' data_info field."""
+
+    field: ClassVar[str] = "n_features"
+    types: ClassVar[Tuple[Type, ...]] = (int,)
+    doc: ClassVar[str] = "Number of input features, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} is deprecated and will be removed.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return isinstance(value, int) and (value > 0)
+
     @classmethod
     def combine(
         cls,
diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py
index 4e201fca..a7e9c167 100644
--- a/declearn/dataset/_base.py
+++ b/declearn/dataset/_base.py
@@ -19,7 +19,7 @@

 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Any, ClassVar, Iterator, Optional, Set, Tuple, Union, List
+from typing import Any, ClassVar, Iterator, List, Optional, Set, Tuple, Union

 from typing_extensions import Self  # future: import from typing (py >=3.11)

@@ -38,7 +38,7 @@ class DataSpecs:
     """Dataclass to wrap a dataset's metadata."""

     n_samples: int
-    single_input_shape: Union[Tuple[int],List[int]]
+    single_input_shape: Union[Tuple[int, ...], List[int]]
     classes: Optional[Set[Any]] = None
     data_type: Optional[str] = None

diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py
index 51d33d2a..6efd2714 100644
--- a/declearn/dataset/_inmemory.py
+++ b/declearn/dataset/_inmemory.py
@@ -440,7 +440,7 @@ class InMemoryDataset(Dataset):
         """Return a DataSpecs object describing this dataset."""
         return DataSpecs(
             n_samples=self.feats.shape[0],
-            single_input_shape=self.feats.shape[1:], # type: ignore
+            single_input_shape=self.feats.shape[1:],  # type: ignore
             classes=self.classes,
             data_type=self.data_type,
         )
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index d4d48a4c..cc5ee6e3 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -136,9 +136,13 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros(
-                (dim, *data_info["single_input_shape"])
+            if len(data_info["single_input_shape"]) != 1:
+                raise ValueError(
+                    "SklearnSGDModel currently only supports "
+                    "flat, one-dimensional features"
             )
+            feat = data_info["single_input_shape"][0]
+            self._model.coef_ = np.zeros((dim, feat))
             self._model.intercept_ = np.zeros((dim,))
         # SGDRegressor case.
         else:
@@ -380,7 +384,7 @@
     ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
         """Return a function to compute point-wise loss for a given batch."""
         # fmt: off
-        # Gather or instantiate a loss function from the wrapped model's specs.
+        # Gather / instantiate a loss function from the wrapped model's specs.
         if hasattr(self._model, "loss_function_"):
             loss_smp = self._model.loss_function_.py_loss
         else:
@@ -399,6 +403,7 @@
         else:
             loss_fn = loss_1d
         return loss_fn
+        # fmt: on

     def update_device_policy(
         self,
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index 452608cb..10863a24 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -29,9 +29,11 @@ from declearn.data_info import aggregate_data_info
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.model.api import Model
 from declearn.model.tensorflow._vector import TensorflowVector
-from declearn.model.tensorflow.utils import (build_keras_loss,
-                                             move_layer_to_device,
-                                             select_device)
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.typing import Batch
 from declearn.utils import DevicePolicy, get_device_policy, register_type

@@ -60,7 +62,7 @@ class TensorflowModel(Model):
     * Note that if the global device-placement policy is updated, this will
       only be propagated to existing instances by manually calling their
       `update_device_policy` method.
-    * You may consult the device policy currently enforced by a TensorflowModel
+    * You may consult the device policy enforced by a TensorflowModel
       instance by accessing its `device_policy` property.
     """

@@ -129,20 +131,20 @@ class TensorflowModel(Model):
     def required_data_info(
         self,
     ) -> Set[str]:
-        return set() if self._model.built else {"n_samples",
-                                                "single_input_shape"}
+        return (
+            set() if self._model.built else {"n_samples", "single_input_shape"}
+        )

     def initialize(
         self,
         data_info: Dict[str, Any],
     ) -> None:
         if not self._model.built:
-            data_info = aggregate_data_info([data_info],
-                                            self.required_data_info)
+            data_info = aggregate_data_info(
+                [data_info], self.required_data_info
+            )
             with tf.device(self._device):
-                self._model.build(
-                    (data_info['n_samples'], *data_info['single_input_shape'])
-                )
+                self._model.build((None, *data_info["single_input_shape"]))

     def get_config(
         self,
@@ -256,6 +258,7 @@ class TensorflowModel(Model):
                 return tf.convert_to_tensor(data)
             # Apply it to the batched elements.
             return tf.nest.map_structure(convert, batch)
+        # fmt: on

     @tf.function  # optimize tensorflow runtime
     def _compute_batch_gradients(
diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index 70a52924..d08b5f16 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -285,7 +285,7 @@ def test_declearn_baseline(
     d_train = InMemoryDataset(train[0], train[1])
     # Set up a declearn model and a vanilla SGD optimizer.
     model = get_model("numpy")
-    model.initialize({"n_features": d_train.data.shape[1]})
+    model.initialize({"single_input_shape": (d_train.data.shape[1],)})
     opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
     # Iteratively train the model, evaluating it after each epoch.
     for _ in range(rounds):
--
GitLab