Commit c9f52458 authored by ANDREY Paul

Merge branch 'sklearn-13-dtype-fix' into 'develop'

Fix compatibility issues with scikit-learn 1.3.0 and tensorflow 2.13.0

Closes #28 and #29

See merge request !50
parents 3b687a3a 38accb64
Pipeline #831148 passed
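In short, this change lets the scikit-learn model wrapper enforce a numeric dtype for model weights and input data. A minimal usage sketch of the new argument (the `declearn.model.sklearn` import path is assumed from the package layout; values are illustrative):

    import numpy as np
    from sklearn.linear_model import SGDClassifier
    from declearn.model.sklearn import SklearnSGDModel

    # Wrap an un-initialized sklearn model, requesting float32 weights.
    model = SklearnSGDModel(SGDClassifier(), dtype="float32")
    # Initialization from data specifications then uses that dtype
    # (on scikit-learn >= 1.3; older versions fall back to float64 with a warning).
    model.initialize({"features_shape": (8,), "classes": [0, 1]})
    assert model.get_config()["dtype"] == "float32"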
@@ -19,10 +19,16 @@
 import typing
 import warnings
-from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, Literal, Optional, Set, Tuple, Type, Union
+)
 
 import numpy as np
 import pandas as pd
+import sklearn  # type: ignore
 from numpy.typing import ArrayLike
+from scipy.sparse import spmatrix  # type: ignore
 from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -56,6 +62,54 @@ REG_LOSSES = (
 )
 
+DataArray = Union[np.ndarray, spmatrix, pd.DataFrame, pd.Series]
+
+
+def select_sgd_model_dtype(
+    model: Union[SGDClassifier, SGDRegressor],
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Return the current or future dtype of a scikit-learn SGD model.
+
+    Emit a `RuntimeWarning` if the model has already been initialized
+    with a dtype that does not match the input one, and/or if the
+    scikit-learn version makes it impossible to enforce a non-default
+    dtype.
+    """
+    # Identify the user-defined dtype's name.
+    dtype = _get_dtype_name(dtype)
+    # Warn if the dtype is already defined and does not match the argument.
+    if hasattr(model, "coef_") and (model.coef_.dtype.name != dtype):
+        warnings.warn(
+            f"Cannot enforce dtype '{dtype}' for pre-initialized "
+            f"scikit-learn SGD model (dtype '{model.coef_.dtype.name}').",
+            RuntimeWarning,
+        )
+        return model.coef_.dtype.name
+    # When using scikit-learn < 1.3, warn about the un-settable dtype.
+    if int(sklearn.__version__.split(".", 2)[1]) < 3 and (dtype != "float64"):
+        warnings.warn(
+            "Using scikit-learn <1.3; hence the 'float64' dtype will "
+            f"forcibly be used rather than user-input '{dtype}'.",
+            RuntimeWarning,
+        )
+        return "float64"
+    return dtype
+
+
+def _get_dtype_name(
+    dtype: Union[str, np.dtype, Type[np.number]],
+) -> str:
+    """Gather the name of an input numpy dtype."""
+    if isinstance(dtype, np.dtype):
+        return dtype.name
+    if isinstance(dtype, type) and issubclass(dtype, np.number):
+        return dtype.__name__
+    if isinstance(dtype, str):
+        return dtype
+    raise TypeError("'dtype' should be a str, np.dtype or np.number type.")
+
+
 @register_type(name="SklearnSGDModel", group="Model")
 class SklearnSGDModel(Model):
     """Model wrapper for Scikit-Learn SGDClassifier and SGDRegressor.
@@ -76,6 +130,7 @@ class SklearnSGDModel(Model):
     def __init__(
         self,
         model: Union[SGDClassifier, SGDRegressor],
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> None:
         """Instantiate a Model interfacing a sklearn SGD-based model.
@@ -89,6 +144,14 @@ class SklearnSGDModel(Model):
             Scikit-learn model that needs wrapping for federated training.
             Note that some hyperparameters will be overridden, as will the
             model's existing weights (if any).
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both of these conditions are met:
+            - `model` is an un-initialized instance;
+              otherwise the dtype of existing coefficients is used.
+            - the scikit-learn version is >= 1.3;
+              otherwise the dtype is forced to "float64".
         """
         if not isinstance(model, (SGDClassifier, SGDRegressor)):
             raise TypeError(
@@ -102,7 +165,7 @@ class SklearnSGDModel(Model):
             average=False,
         )
         super().__init__(model)
-        self._initialized = False
+        self._dtype = select_sgd_model_dtype(model, dtype)
         self._predict = (
             self._model.decision_function
             if isinstance(model, SGDClassifier)
@@ -122,6 +185,8 @@ class SklearnSGDModel(Model):
     def required_data_info(
         self,
     ) -> Set[str]:
+        if hasattr(self._model, "coef_"):
+            return set()
         if isinstance(self._model, SGDRegressor):
             return {"features_shape"}
         return {"features_shape", "classes"}
@@ -130,6 +195,9 @@ class SklearnSGDModel(Model):
         self,
         data_info: Dict[str, Any],
     ) -> None:
+        # Skip for pre-initialized models.
+        if hasattr(self._model, "coef_"):
+            return
         # Check that required fields are available and of valid type.
         data_info = aggregate_data_info([data_info], self.required_data_info)
         if not (
@@ -145,14 +213,12 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros((dim, feat))
-            self._model.intercept_ = np.zeros((dim,))
+            self._model.coef_ = np.zeros((dim, feat), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((dim,), dtype=self._dtype)
         # SGDRegressor case.
         else:
-            self._model.coef_ = np.zeros((feat,))
-            self._model.intercept_ = np.zeros((1,))
-        # Mark the SklearnSGDModel as initialized.
-        self._initialized = True
+            self._model.coef_ = np.zeros((feat,), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((1,), dtype=self._dtype)
     @classmethod
     def from_parameters(
@@ -165,6 +231,7 @@ class SklearnSGDModel(Model):
         epsilon: float = 0.1,
         fit_intercept: bool = True,
         n_jobs: Optional[int] = None,
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> Self:
         """Instantiate a SklearnSGDModel from model parameters.
@@ -198,6 +265,14 @@ class SklearnSGDModel(Model):
             Number of CPUs to use when computing one-versus-all.
             Only used for multi-class classifiers.
             `None` means 1, while -1 means all available CPUs.
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both of these conditions are met:
+            - the wrapped model is an un-initialized instance;
+              otherwise the dtype of existing coefficients is used.
+            - the scikit-learn version is >= 1.3;
+              otherwise the dtype is forced to "float64".
 
         Notes
         -----
@@ -237,22 +312,26 @@ class SklearnSGDModel(Model):
             fit_intercept=fit_intercept,
             **kwargs,
         )
-        return cls(model)
+        return cls(model, dtype)
 
     def get_config(
         self,
     ) -> Dict[str, Any]:
         is_clf = isinstance(self._model, SGDClassifier)
         data_info = None  # type: Optional[Dict[str, Any]]
-        if self._initialized:
+        if hasattr(self._model, "coef_"):
             data_info = {
                 "features_shape": (self._model.coef_.shape[-1],),
                 "classes": self._model.classes_.tolist() if is_clf else None,
             }
+            dtype = self._model.coef_.dtype.name
+        else:
+            dtype = self._dtype
         return {
             "kind": "classifier" if is_clf else "regressor",
             "params": self._model.get_params(),
             "data_info": data_info,
+            "dtype": dtype,
         }
 
     @classmethod
@@ -261,13 +340,14 @@ class SklearnSGDModel(Model):
         config: Dict[str, Any],
     ) -> Self:
         """Instantiate a SklearnSGDModel from a configuration dict."""
-        for key in ("kind", "params"):
+        for key in ("kind", "params", "dtype"):
             if key not in config:
                 raise KeyError(f"Missing key '{key}' in the config dict.")
         if config["kind"] == "classifier":
-            model = cls(SGDClassifier(**config["params"]))
+            skmod = SGDClassifier(**config["params"])
         else:
-            model = cls(SGDRegressor(**config["params"]))
+            skmod = SGDRegressor(**config["params"])
+        model = cls(skmod, dtype=config["dtype"])
         if config.get("data_info"):
             model.initialize(config["data_info"])
         return model
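A quick round-trip sketch of the updated (de)serialization (assuming the `declearn.model.sklearn` import path and scikit-learn >= 1.3):

    from sklearn.linear_model import SGDRegressor
    from declearn.model.sklearn import SklearnSGDModel

    model = SklearnSGDModel(SGDRegressor(), dtype="float32")
    config = model.get_config()
    # The config now records the dtype, hence 'from_config' restores it.
    assert config["dtype"] == "float32"
    restored = SklearnSGDModel.from_config(config)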
@@ -306,8 +386,7 @@ class SklearnSGDModel(Model):
         x_data, y_data, s_wght = self._unpack_batch(batch)
         # Iteratively compute sample-wise gradients.
         grad = [
-            self._compute_sample_gradient(x, y)  # type: ignore
-            for x, y in zip(x_data, y_data)  # type: ignore
+            self._compute_sample_gradient(x, y) for x, y in zip(x_data, y_data)
         ]
         # Optionally clip sample-wise gradients based on their L2 norm.
         if max_norm:
@@ -317,27 +396,45 @@ class SklearnSGDModel(Model):
                 arr *= min(max_norm / norm, 1)
         # Optionally re-weight gradients based on sample weights.
         if s_wght is not None:
-            grad = [g * w for g, w in zip(grad, s_wght)]  # type: ignore
+            grad = [g * w for g, w in zip(grad, s_wght)]
         # Batch-average the gradients and return them.
         return sum(grad) / len(grad)  # type: ignore
 
     def _unpack_batch(
         self,
         batch: Batch,
-    ) -> Tuple[ArrayLike, ArrayLike, Optional[ArrayLike]]:
+    ) -> Tuple[DataArray, DataArray, Optional[DataArray]]:
         """Verify and unpack an input batch into (x, y, [w]).
 
         Note: this method does not verify arrays' dimensionality or
         shape coherence; the wrapped sklearn objects already do so.
         """
         x_data, y_data, s_wght = batch
-        invalid = (y_data is None) or isinstance(y_data, list)
-        if invalid or isinstance(x_data, list):
+        if (
+            (y_data is None)
+            or isinstance(y_data, list)
+            or isinstance(x_data, list)
+        ):
             raise TypeError(
                 "'SklearnSGDModel' requires (array, array, [array|None]) "
                 "data batches."
             )
-        return x_data, y_data, s_wght  # type: ignore
+        x_data = self._validate_and_cast_array(x_data)
+        y_data = self._validate_and_cast_array(y_data)
+        if s_wght is not None:
+            s_wght = self._validate_and_cast_array(s_wght)
+        return x_data, y_data, s_wght
+
+    def _validate_and_cast_array(
+        self,
+        array: ArrayLike,
+    ) -> DataArray:
+        """Type-check and type-cast an input data array."""
+        if not isinstance(array, typing.get_args(DataArray)):
+            raise TypeError(
+                f"Invalid data type for 'SklearnSGDModel': '{type(array)}'."
+            )
+        return array.astype(self._dtype, copy=False)  # type: ignore
 
     def _compute_sample_gradient(
         self,
@@ -85,22 +85,9 @@ def build_keras_loss(
     # Case when 'loss' is already a Loss object.
     if isinstance(loss, tf.keras.losses.Loss):
         loss.reduction = reduction
-    # Case when 'loss' is a string.
+    # Case when 'loss' is a string: deserialize and/or wrap into a Loss object.
     elif isinstance(loss, str):
-        cls = tf.keras.losses.deserialize(loss)
-        # Case when the string was deserialized into a function.
-        if inspect.isfunction(cls):
-            # Try altering the string to gather its object counterpart.
-            loss = "".join(word.capitalize() for word in loss.split("_"))
-            try:
-                loss = tf.keras.losses.deserialize(loss)
-                loss.reduction = reduction
-            # If this failed, try wrapping the function using LossFunction.
-            except ValueError:
-                loss = LossFunction(cls)
-        # Case when the string was deserialized into a class.
-        else:
-            loss = cls(reduction=reduction)
+        loss = get_keras_loss_from_string(name=loss, reduction=reduction)
     # Case when 'loss' is a function: wrap it up using LossFunction.
     elif inspect.isfunction(loss):
         loss = LossFunction(loss, reduction=reduction)
@@ -111,3 +98,32 @@ def build_keras_loss(
         )
     # Otherwise, properly configure the reduction scheme and return.
     return loss
+
+
+def get_keras_loss_from_string(
+    name: str,
+    reduction: str,
+) -> tf.keras.losses.Loss:
+    """Instantiate a keras Loss object from a registered string identifier.
+
+    - If `name` matches a Loss registration name, return an instance.
+    - If it matches a loss function registration name, return either
+      an instance from its name-matching Loss subclass, or a custom
+      Loss subclass instance wrapping the function.
+    - If it does not match anything, raise a ValueError.
+    """
+    loss = tf.keras.losses.deserialize(name)
+    if isinstance(loss, tf.keras.losses.Loss):
+        loss.reduction = reduction
+        return loss
+    if inspect.isfunction(loss):
+        try:
+            name = "".join(word.capitalize() for word in name.split("_"))
+            return get_keras_loss_from_string(name, reduction)
+        except ValueError:
+            return LossFunction(
+                loss, reduction=reduction, name=getattr(loss, "__name__", None)
+            )
+    raise ValueError(
+        f"Name '{name}' cannot be deserialized into a keras loss."
+    )
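To illustrate the resolution rules implemented above (a sketch; `LossFunction` is declearn's internal Loss-wrapping class used elsewhere in this module):

    import tensorflow as tf

    # A Loss-registered name yields an instance of that class.
    loss = get_keras_loss_from_string("BinaryCrossentropy", reduction="sum")
    assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)

    # A function name with a class counterpart is promoted to the class.
    loss = get_keras_loss_from_string("binary_crossentropy", reduction="sum")
    assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)

    # An alias without a class counterpart gets wrapped via LossFunction.
    loss = get_keras_loss_from_string("mse", reduction="sum")
    assert loss.loss_fn is tf.keras.losses.mse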
@@ -89,7 +89,8 @@ class DeclearnTestCase:
         """Return a Model suitable for the learning task and framework."""
         if self.framework.lower() == "sksgd":
             return SklearnSGDModel.from_parameters(
-                kind=("regressor" if self.kind == "Reg" else "classifier")
+                kind=("regressor" if self.kind == "Reg" else "classifier"),
+                dtype="float32",
             )
         if self.framework.lower() == "tflow":
             return self._build_tflow_model()
@@ -274,7 +275,7 @@ def run_test_case(
     )
 
 
-@pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"])
+@pytest.mark.parametrize("strategy", ["FedAvg", "Scaffold"])
 @pytest.mark.parametrize("framework", FRAMEWORKS)
 @pytest.mark.parametrize("kind", ["Reg", "Bin", "Clf"])
 @pytest.mark.filterwarnings("ignore: PyTorch JSON serialization")
@@ -293,7 +294,7 @@ def test_declearn(
     Note: If websockets is unavailable, use gRPC (warn) or fail.
     """
     if not fulltest:
-        if (kind != "Reg") or (strategy == "FedAvgM"):
+        if (kind != "Reg") or (strategy == "FedAvg"):
             pytest.skip("skip scenario (no --fulltest option)")
     protocol = "websockets"  # type: Literal["grpc", "websockets"]
     if "websockets" not in list_available_protocols():
@@ -70,7 +70,7 @@ class SklearnSGDTestCase(ModelTestCase):
         assert isinstance(tensor, (np.ndarray, csr_matrix))
         if isinstance(tensor, csr_matrix):
             tensor = tensor.toarray()
-        return tensor  # type: ignore
+        return tensor
 
     @property
     def dataset(
@@ -78,15 +78,15 @@ class SklearnSGDTestCase(ModelTestCase):
     ) -> List[Batch]:
         """Suited toy binary-classification dataset."""
         rng = np.random.default_rng(seed=0)
-        inputs = rng.normal(size=(2, 32, 8))
+        inputs = rng.normal(size=(2, 32, 8)).astype("float32")
         if self.as_sparse:
             inputs = [csr_matrix(arr) for arr in inputs]  # type: ignore
         if isinstance(self.n_classes, int):
-            labels = rng.choice(self.n_classes, size=(2, 32)).astype(float)
+            labels = rng.choice(self.n_classes, size=(2, 32)).astype("float32")
         else:
-            labels = rng.normal(size=(2, 32))
+            labels = rng.normal(size=(2, 32)).astype("float32")
         if self.s_weights:
-            s_wght = np.exp(rng.normal(size=(2, 32)))
+            s_wght = np.exp(rng.normal(size=(2, 32)).astype("float32"))
             s_wght /= s_wght.sum(axis=1, keepdims=True) * 32
             batches = list(zip(inputs, labels, s_wght))
         else:
@@ -99,7 +99,7 @@ class SklearnSGDTestCase(ModelTestCase):
     ) -> SklearnSGDModel:
         """Suited toy binary-classification model."""
         skmod = (SGDClassifier if self.n_classes else SGDRegressor)()
-        model = SklearnSGDModel(skmod)
+        model = SklearnSGDModel(skmod, dtype="float32")
         data_info = {"features_shape": (8,)}  # type: Dict[str, Any]
         if self.n_classes:
             data_info["classes"] = np.arange(self.n_classes)
@@ -32,6 +32,7 @@ except ModuleNotFoundError:
     pytest.skip("TensorFlow is unavailable", allow_module_level=True)
 
 from declearn.model.tensorflow import TensorflowModel, TensorflowVector
+from declearn.model.tensorflow.utils import build_keras_loss
 from declearn.typing import Batch
 from declearn.utils import set_device_policy
@@ -221,3 +222,66 @@ class TestTensorflowModel(ModelTestSuite):
         device = f"{test_case.device}:0"
         for var in tfmod.weights:
             assert var.device.endswith(device)
+
+
+class TestBuildKerasLoss:
+    """Unit tests for the `build_keras_loss` util function."""
+
+    def test_build_keras_loss_from_string_class_name(self) -> None:
+        """Test `build_keras_loss` with a valid class name string input."""
+        loss = build_keras_loss(
+            "BinaryCrossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_function_name(self) -> None:
+        """Test `build_keras_loss` with a valid function name string input."""
+        loss = build_keras_loss(
+            "binary_crossentropy", tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_string_noclass_function_name(self) -> None:
+        """Test `build_keras_loss` with a function name lacking a Loss class."""
+        loss = build_keras_loss("mse", tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.mse
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_loss_instance(self) -> None:
+        """Test `build_keras_loss` with a valid keras Loss input."""
+        # Set up a BinaryCrossentropy loss instance.
+        loss = tf.keras.losses.BinaryCrossentropy(
+            reduction=tf.keras.losses.Reduction.SUM
+        )
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+        # Pass it through the util and verify that the reduction changes.
+        loss = build_keras_loss(loss, tf.keras.losses.Reduction.NONE)
+        assert isinstance(loss, tf.keras.losses.BinaryCrossentropy)
+        assert loss.reduction == tf.keras.losses.Reduction.NONE
+
+    def test_build_keras_loss_from_loss_function(self) -> None:
+        """Test `build_keras_loss` with a valid keras loss function input."""
+        loss = build_keras_loss(
+            tf.keras.losses.binary_crossentropy, tf.keras.losses.Reduction.SUM
+        )
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is tf.keras.losses.binary_crossentropy
+        assert loss.reduction == tf.keras.losses.Reduction.SUM
+
+    def test_build_keras_loss_from_custom_function(self) -> None:
+        """Test `build_keras_loss` with a valid custom loss function input."""
+
+        def loss_fn(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+            """Custom loss function."""
+            return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32))
+
+        loss = build_keras_loss(loss_fn, tf.keras.losses.Reduction.SUM)
+        assert isinstance(loss, tf.keras.losses.Loss)
+        assert hasattr(loss, "loss_fn")
+        assert loss.loss_fn is loss_fn
+        assert loss.reduction == tf.keras.losses.Reduction.SUM