
Commit 0f4eced5 authored by BIGAUD Nathan
Applying corrections after MR review:

* Re-introducing deprecated `DataField` subclasses with a warning
* Other minor changes to the `sklearn` and `tf` models and to `DataTypeField`
parent 83204307
1 merge request: !36 Adding data_type to DataSpecs
Pipeline #767528 waiting for manual action
@@ -57,7 +57,9 @@ from ._base import (
 )
 from ._fields import (
     ClassesField,
-    SingleInputShapeField,
+    DataTypeField,
+    InputShapeField,
+    NbFeaturesField,
     NbSamplesField,
-    DataTypeField
+    SingleInputShapeField,
 )
@@ -17,7 +17,8 @@
 """DataInfoField subclasses specifying common 'data_info' metadata fields."""

-from typing import Any, ClassVar, Optional, Set, Tuple, Type
+from typing import Any, ClassVar, List, Optional, Set, Tuple, Type
+from warnings import warn

 import numpy as np
@@ -27,6 +28,8 @@ __all__ = [
     "ClassesField",
     "SingleInputShapeField",
     "NbSamplesField",
+    "NbFeaturesField",
+    "InputShapeField",
 ]
@@ -91,39 +94,7 @@ class SingleInputShapeField(DataInfoField):
                 f"Cannot combine '{cls.field}': inputs don't have the same "
                 "shape"
             )
-        return unique_shapes[0]  # type: ignore
-
-
-# @register_data_info_field
-# class NbFeaturesField(DataInfoField):
-#     """Specifications for 'n_features' data_info field."""
-#
-#     field: ClassVar[str] = "n_features"
-#     types: ClassVar[Tuple[Type, ...]] = (int,)
-#     doc: ClassVar[str] = "Number of input features, checked to be equal."
-#
-#     @classmethod
-#     def is_valid(
-#         cls,
-#         value: Any,
-#     ) -> bool:
-#         return isinstance(value, int) and (value > 0)
-#
-#     @classmethod
-#     def combine(
-#         cls,
-#         *values: Any,
-#     ) -> int:
-#         unique = list(set(values))
-#         if len(unique) != 1:
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': non-unique inputs."
-#             )
-#         if not cls.is_valid(unique[0]):
-#             raise ValueError(
-#                 f"Cannot combine '{cls.field}': invalid unique value."
-#             )
-#         return unique[0]
+        return unique_shapes[0]  # type: ignore


 @register_data_info_field
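
For reference, a usage sketch of the field retained above (hedged: it assumes these classes are re-exported from declearn.data_info, as the import hunk suggests):

    from declearn.data_info import SingleInputShapeField

    SingleInputShapeField.combine((32,), (32,))  # -> (32,)
    SingleInputShapeField.combine((32,), (64,))  # ValueError: inputs don't
                                                 # have the same shape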
@@ -165,16 +136,108 @@ class DataTypeField(DataInfoField):
     ) -> bool:
         out = isinstance(value, str)
         if out:
-            # CHECK
             try:
                 np.dtype(value)
-            except TypeError as exp:
-                raise TypeError(
-                    "The received string could not be parsed"
-                    "to a valid array dtype"
-                ) from exp
+            except TypeError:
+                out = False
         return out

+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> str:
+        super().combine(*values)
+        unique = list(set(values))
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': non-unique inputs."
+            )
+        if not cls.is_valid(unique[0]):
+            raise ValueError(
+                f"Cannot combine '{cls.field}': invalid unique value."
+            )
+        return unique[0]
+
+
+@register_data_info_field
+class InputShapeField(DataInfoField):
+    """Deprecated specifications for 'input_shape' data_info field."""
+
+    field: ClassVar[str] = "input_shape"
+    types: ClassVar[Tuple[Type, ...]] = (tuple, list)
+    doc: ClassVar[str] = "Input features' batched shape, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} will be deprecated.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return (
+            isinstance(value, cls.types)
+            and (len(value) >= 2)
+            and all(isinstance(val, int) or (val is None) for val in value)
+        )
+
+    @classmethod
+    def combine(
+        cls,
+        *values: Any,
+    ) -> List[Optional[int]]:
+        # Type-check each and every input shape.
+        super().combine(*values)
+        # Check that all shapes are of the same length.
+        unique = list({len(shp) for shp in values})
+        if len(unique) != 1:
+            raise ValueError(
+                f"Cannot combine '{cls.field}': inputs have various lengths."
+            )
+        # Fill in the unified shape: each dimension must either be all-None
+        # or carry a unique non-None value (None acting as a wildcard).
+        # Note: the batching dimension is set to None by default (no check).
+        shape = [None] * unique[0]  # type: List[Optional[int]]
+        for i in range(1, unique[0]):
+            val = [shp[i] for shp in values if shp[i] is not None]
+            if not val:  # all None
+                shape[i] = None
+            elif len(set(val)) > 1:
+                raise ValueError(
+                    f"Cannot combine '{cls.field}': provided shapes differ."
+                )
+            else:
+                shape[i] = val[0]
+        # Return the combined shape.
+        return shape
+
+
+@register_data_info_field
+class NbFeaturesField(DataInfoField):
+    """Deprecated specifications for 'n_features' data_info field."""
+
+    field: ClassVar[str] = "n_features"
+    types: ClassVar[Tuple[Type, ...]] = (int,)
+    doc: ClassVar[str] = "Number of input features, checked to be equal."
+
+    def __init__(self) -> None:
+        warn(
+            f"{self.__class__.__name__} will be deprecated.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    @classmethod
+    def is_valid(
+        cls,
+        value: Any,
+    ) -> bool:
+        return isinstance(value, int) and (value > 0)
+
+    @classmethod
+    def combine(
+        cls,
......
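
An illustrative sketch of the amended DataTypeField and the re-introduced deprecated fields (hedged: package-level re-exports assumed, values made up):

    import warnings

    from declearn.data_info import DataTypeField, InputShapeField

    DataTypeField.is_valid("float32")       # True: np.dtype("float32") parses
    DataTypeField.is_valid("no-such-type")  # False: np.dtype raises TypeError
    DataTypeField.combine("float32", "float32")  # -> "float32"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        InputShapeField()  # emits a DeprecationWarning upon instantiation
    assert issubclass(caught[0].category, DeprecationWarning)

    # None acts as a wildcard on non-batch axes; axis 0 is never checked.
    InputShapeField.combine([None, 32, 3], [8, 32, 3])  # -> [None, 32, 3]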
@@ -19,7 +19,7 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Any, ClassVar, Iterator, Optional, Set, Tuple, Union, List
+from typing import Any, ClassVar, Iterator, List, Optional, Set, Tuple, Union
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -38,7 +38,7 @@ class DataSpecs:
     """Dataclass to wrap a dataset's metadata."""

     n_samples: int
-    single_input_shape: Union[Tuple[int],List[int]]
+    single_input_shape: Union[Tuple[int], List[int]]
     classes: Optional[Set[Any]] = None
     data_type: Optional[str] = None
......
@@ -440,7 +440,7 @@ class InMemoryDataset(Dataset):
         """Return a DataSpecs object describing this dataset."""
         return DataSpecs(
             n_samples=self.feats.shape[0],
-            single_input_shape=self.feats.shape[1:], # type: ignore
+            single_input_shape=self.feats.shape[1:],  # type: ignore
             classes=self.classes,
             data_type=self.data_type,
         )
......
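
For orientation, a minimal sketch of the record the two hunks above produce, with the batching dimension excluded from single_input_shape (import path and values assumed for illustration):

    from declearn.dataset import DataSpecs  # assumed re-export path

    specs = DataSpecs(
        n_samples=100,             # feats.shape[0]
        single_input_shape=(32,),  # feats.shape[1:], i.e. per-sample shape
        classes={0, 1},
        data_type="float32",
    )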
@@ -136,9 +136,13 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros(
-                (dim, *data_info["single_input_shape"])
-            )
+            if len(data_info["single_input_shape"]) != 1:
+                raise ValueError(
+                    "SklearnSGDModel currently only supports "
+                    "flat, one-dimensional features."
+                )
+            feat = data_info["single_input_shape"][0]
+            self._model.coef_ = np.zeros((dim, feat))
             self._model.intercept_ = np.zeros((dim,))
         # SGDRegressor case.
         else:
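
Spelling out the new branch on made-up values (a 3-class classifier over 32 flat features; this mirrors the diff, not the full method):

    import numpy as np

    data_info = {"classes": {0, 1, 2}, "single_input_shape": (32,)}
    n_classes = len(data_info["classes"])      # 3
    dim = n_classes if (n_classes > 2) else 1  # 3 (a binary task yields 1)
    feat = data_info["single_input_shape"][0]  # 32; longer shapes now raise
    coef = np.zeros((dim, feat))               # coef_ takes shape (3, 32)
    intercept = np.zeros((dim,))               # intercept_ takes shape (3,)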
@@ -380,7 +384,7 @@
     ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
         """Return a function to compute point-wise loss for a given batch."""
         # fmt: off
-        # Gather or instantiate a loss function from the wrapped model's specs.
+        # Gather / instantiate a loss function from the wrapped model's specs.
         if hasattr(self._model, "loss_function_"):
             loss_smp = self._model.loss_function_.py_loss
         else:
@@ -399,6 +403,7 @@
         else:
             loss_fn = loss_1d
         return loss_fn
+        # fmt: on

     def update_device_policy(
         self,
......
@@ -29,9 +29,11 @@ from declearn.data_info import aggregate_data_info
 from declearn.model._utils import raise_on_stringsets_mismatch
 from declearn.model.api import Model
 from declearn.model.tensorflow._vector import TensorflowVector
-from declearn.model.tensorflow.utils import (build_keras_loss,
-                                             move_layer_to_device,
-                                             select_device)
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.typing import Batch
 from declearn.utils import DevicePolicy, get_device_policy, register_type
@@ -60,7 +62,7 @@ class TensorflowModel(Model):
     * Note that if the global device-placement policy is updated, this will
       only be propagated to existing instances by manually calling their
       `update_device_policy` method.
-    * You may consult the device policy currently enforced by a TensorflowModel
+    * You may consult the device policy enforced by a TensorflowModel
       instance by accessing its `device_policy` property.
     """
@@ -129,20 +131,20 @@
     def required_data_info(
         self,
     ) -> Set[str]:
-        return set() if self._model.built else {"n_samples",
-                                                "single_input_shape"}
+        return (
+            set() if self._model.built else {"n_samples", "single_input_shape"}
+        )

     def initialize(
         self,
         data_info: Dict[str, Any],
     ) -> None:
         if not self._model.built:
-            data_info = aggregate_data_info([data_info],
-                                            self.required_data_info)
+            data_info = aggregate_data_info(
+                [data_info], self.required_data_info
+            )
             with tf.device(self._device):
-                self._model.build(
-                    (data_info['n_samples'], *data_info['single_input_shape'])
-                )
+                self._model.build((None, *data_info["single_input_shape"]))

     def get_config(
         self,
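
The rewritten build call passes None as the batch dimension, so the built model accepts any batch size instead of being tied to n_samples. A minimal sketch with a hypothetical Keras model:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    single_input_shape = (32,)
    model.build((None, *single_input_shape))  # input spec: (batch, 32)
    model(tf.zeros((4, 32)))   # any batch size now works
    model(tf.zeros((16, 32)))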
@@ -256,6 +258,7 @@
             return tf.convert_to_tensor(data)
         # Apply it to the batched elements.
         return tf.nest.map_structure(convert, batch)
+        # fmt: on

     @tf.function  # optimize tensorflow runtime
     def _compute_batch_gradients(
......
@@ -285,7 +285,7 @@ def test_declearn_baseline(
     d_train = InMemoryDataset(train[0], train[1])
     # Set up a declearn model and a vanilla SGD optimizer.
     model = get_model("numpy")
-    model.initialize({"n_features": d_train.data.shape[1]})
+    model.initialize({"single_input_shape": (d_train.data.shape[1],)})
     opt = Optimizer(lrate=lrate, regularizers=[("lasso", {"alpha": 0.1})])
     # Iteratively train the model, evaluating it after each epoch.
     for _ in range(rounds):
......
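
The test change shows the migration path for callers: the deprecated n_features integer becomes a length-1 single_input_shape tuple. Schematically (n being the number of feature columns):

    model.initialize({"n_features": n})             # deprecated field
    model.initialize({"single_input_shape": (n,)})  # its replacement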