From 02afef6ec467896cc9fdcdaab8df238860e40ef1 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 5 Jul 2023 18:07:21 +0200
Subject: [PATCH] Improve dtype handling in 'SklearnSGDModel'.

- Scikit-Learn 1.3.0 introduced the possibility to use any dtype for
  SGD-based models' weights.
- As a consequence, this commit introduces the optional 'dtype' argument
  to 'SklearnSGDModel.__init__' and '.from_parameters' methods, which is
  only used with `scikit-learn >=1.3` and an un-initialized model.
- Consistently (and for all scikit-learn versions), additional type/dtype
  verification has been implemented under the hood of `_unpack_batch`.
---
 declearn/model/sklearn/_sgd.py | 111 +++++++++++++++++++++++++++------
 test/functional/test_main.py   |   3 +-
 2 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 6ab4c880..7825fa65 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -19,10 +19,16 @@
 
 import typing
 import warnings
-from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, Literal, Optional, Set, Tuple, Type, Union
+)
 
 import numpy as np
 from numpy.typing import ArrayLike
+import pandas as pd
+import sklearn  # type: ignore
+from scipy.sparse import spmatrix  # type: ignore
 from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
@@ -56,6 +62,9 @@ REG_LOSSES = (
 )
 
 
+DataArray = Union[np.ndarray, spmatrix, pd.DataFrame, pd.Series]
+
+
 @register_type(name="SklearnSGDModel", group="Model")
 class SklearnSGDModel(Model):
     """Model wrapper for Scikit-Learn SGDClassifier and SGDRegressor.
@@ -76,6 +85,7 @@ class SklearnSGDModel(Model):
     def __init__(
         self,
         model: Union[SGDClassifier, SGDRegressor],
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> None:
         """Instantiate a Model interfacing a sklearn SGD-based model.
 
@@ -89,6 +99,14 @@ class SklearnSGDModel(Model):
             Scikit-learn model that needs wrapping for federated training.
             Note that some hyperparameters will be overridden, as will the
             model's existing weights (if any).
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both these conditions are met:
+                - `model` is an un-initialized instance;
+                  otherwise the dtype of existing coefficients is used.
+                - the scikit-learn version is >= 1.3;
+                  otherwise the dtype is forced to float64.
         """
         if not isinstance(model, (SGDClassifier, SGDRegressor)):
             raise TypeError(
@@ -102,7 +120,26 @@ class SklearnSGDModel(Model):
             average=False,
         )
         super().__init__(model)
-        self._initialized = False
+        if hasattr(model, "coef_"):
+            self._dtype = self._model.coef_.dtype.name
+        elif isinstance(dtype, np.dtype):
+            self._dtype = dtype.name
+        elif isinstance(dtype, type) and issubclass(dtype, np.number):
+            self._dtype = dtype.__name__
+        elif isinstance(dtype, str):
+            self._dtype = dtype
+        else:
+            raise TypeError(
+                "'dtype' should be a str, np.dtype or np.number type."
+            )
+        if int(sklearn.__version__.split(".", 2)[1]) < 3:
+            if self._dtype != "float64":
+                warnings.warn(
+                    "Using scikit-learn <1.3; hence the float64 dtype will "
+                    f"forcibly be used rather than user-input {self._dtype}.",
+                    RuntimeWarning,
+                )
+                self._dtype = "float64"
         self._predict = (
             self._model.decision_function
             if isinstance(model, SGDClassifier)
@@ -122,6 +159,8 @@ class SklearnSGDModel(Model):
     def required_data_info(
         self,
     ) -> Set[str]:
+        if hasattr(self._model, "coef_"):
+            return set()
         if isinstance(self._model, SGDRegressor):
             return {"features_shape"}
         return {"features_shape", "classes"}
@@ -130,6 +169,9 @@ class SklearnSGDModel(Model):
         self,
         data_info: Dict[str, Any],
     ) -> None:
+        # Skip for pre-initialized models.
+        if hasattr(self._model, "coef_"):
+            return
         # Check that required fields are available and of valid type.
         data_info = aggregate_data_info([data_info], self.required_data_info)
         if not (
@@ -145,14 +187,12 @@ class SklearnSGDModel(Model):
             self._model.classes_ = np.array(list(data_info["classes"]))
             n_classes = len(self._model.classes_)
             dim = n_classes if (n_classes > 2) else 1
-            self._model.coef_ = np.zeros((dim, feat))
-            self._model.intercept_ = np.zeros((dim,))
+            self._model.coef_ = np.zeros((dim, feat), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((dim,), dtype=self._dtype)
         # SGDRegressor case.
         else:
-            self._model.coef_ = np.zeros((feat,))
-            self._model.intercept_ = np.zeros((1,))
-        # Mark the SklearnSGDModel as initialized.
-        self._initialized = True
+            self._model.coef_ = np.zeros((feat,), dtype=self._dtype)
+            self._model.intercept_ = np.zeros((1,), dtype=self._dtype)
 
     @classmethod
     def from_parameters(
@@ -165,6 +205,7 @@ class SklearnSGDModel(Model):
         epsilon: float = 0.1,
         fit_intercept: bool = True,
         n_jobs: Optional[int] = None,
+        dtype: Union[str, np.dtype, Type[np.number]] = "float64",
     ) -> Self:
         """Instantiate a SklearnSGDModel from model parameters.
 
@@ -198,6 +239,14 @@ class SklearnSGDModel(Model):
             Number of CPUs to use when to compute one-versus-all.
             Only used for multi-class classifiers.
             `None` means 1, while -1 means all available CPUs.
+        dtype: str or numpy.dtype or type[np.number], default="float64"
+            Data type to enforce for the model's coefficients. Input data
+            will be cast to the matching dtype.
+            Only used when both these conditions are met:
+                - `model` is an un-initialized instance;
+                  otherwise the dtype of existing coefficients is used.
+                - the scikit-learn version is >= 1.3;
+                  otherwise the dtype is forced to float64.
 
         Notes
         -----
@@ -237,22 +286,26 @@ class SklearnSGDModel(Model):
             fit_intercept=fit_intercept,
             **kwargs,
         )
-        return cls(model)
+        return cls(model, dtype)
 
     def get_config(
         self,
     ) -> Dict[str, Any]:
         is_clf = isinstance(self._model, SGDClassifier)
         data_info = None  # type: Optional[Dict[str, Any]]
-        if self._initialized:
+        if hasattr(self._model, "coef_"):
             data_info = {
                 "features_shape": (self._model.coef_.shape[-1],),
                 "classes": self._model.classes_.tolist() if is_clf else None,
             }
+            dtype = self._model.coef_.dtype.name
+        else:
+            dtype = self._dtype
         return {
             "kind": "classifier" if is_clf else "regressor",
             "params": self._model.get_params(),
             "data_info": data_info,
+            "dtype": dtype,
         }
 
     @classmethod
@@ -261,13 +314,14 @@ class SklearnSGDModel(Model):
         config: Dict[str, Any],
     ) -> Self:
         """Instantiate a SklearnSGDModel from a configuration dict."""
-        for key in ("kind", "params"):
+        for key in ("kind", "params", "dtype"):
             if key not in config:
                 raise KeyError(f"Missing key '{key}' in the config dict.")
         if config["kind"] == "classifier":
-            model = cls(SGDClassifier(**config["params"]))
+            skmod = SGDClassifier(**config["params"])
         else:
-            model = cls(SGDRegressor(**config["params"]))
+            skmod = SGDRegressor(**config["params"])
+        model = cls(skmod, dtype=config["dtype"])
         if config.get("data_info"):
             model.initialize(config["data_info"])
         return model
@@ -306,8 +360,7 @@ class SklearnSGDModel(Model):
         x_data, y_data, s_wght = self._unpack_batch(batch)
         # Iteratively compute sample-wise gradients.
         grad = [
-            self._compute_sample_gradient(x, y)  # type: ignore
-            for x, y in zip(x_data, y_data)  # type: ignore
+            self._compute_sample_gradient(x, y) for x, y in zip(x_data, y_data)
         ]
         # Optionally clip sample-wise gradients based on their L2 norm.
         if max_norm:
@@ -317,27 +370,45 @@ class SklearnSGDModel(Model):
                     arr *= min(max_norm / norm, 1)
         # Optionally re-weight gradients based on sample weights.
         if s_wght is not None:
-            grad = [g * w for g, w in zip(grad, s_wght)]  # type: ignore
+            grad = [g * w for g, w in zip(grad, s_wght)]
         # Batch-average the gradients and return them.
         return sum(grad) / len(grad)  # type: ignore
 
     def _unpack_batch(
         self,
         batch: Batch,
-    ) -> Tuple[ArrayLike, ArrayLike, Optional[ArrayLike]]:
+    ) -> Tuple[DataArray, DataArray, Optional[DataArray]]:
         """Verify and unpack an input batch into (x, y, [w]).
 
         Note: this method does not verify arrays' dimensionality or
         shape coherence; the wrapped sklearn objects already do so.
         """
         x_data, y_data, s_wght = batch
-        invalid = (y_data is None) or isinstance(y_data, list)
-        if invalid or isinstance(x_data, list):
+        if (
+            (y_data is None)
+            or isinstance(y_data, list)
+            or isinstance(x_data, list)
+        ):
             raise TypeError(
                 "'SklearnSGDModel' requires (array, array, [array|None]) "
                 "data batches."
             )
-        return x_data, y_data, s_wght  # type: ignore
+        x_data = self._validate_and_cast_array(x_data)
+        y_data = self._validate_and_cast_array(y_data)
+        if s_wght is not None:
+            s_wght = self._validate_and_cast_array(s_wght)
+        return x_data, y_data, s_wght
+
+    def _validate_and_cast_array(
+        self,
+        array: ArrayLike,
+    ) -> DataArray:
+        """Type-check and type-cast an input data array."""
+        if not isinstance(array, typing.get_args(DataArray)):
+            raise TypeError(
+                f"Invalid data type for 'SklearnSGDModel': '{type(array)}'."
+            )
+        return array.astype(self._dtype, copy=False)  # type: ignore
 
     def _compute_sample_gradient(
         self,
diff --git a/test/functional/test_main.py b/test/functional/test_main.py
index 1f1a0bfe..b5f179da 100644
--- a/test/functional/test_main.py
+++ b/test/functional/test_main.py
@@ -89,7 +89,8 @@ class DeclearnTestCase:
         """Return a Model suitable for the learning task and framework."""
         if self.framework.lower() == "sksgd":
             return SklearnSGDModel.from_parameters(
-                kind=("regressor" if self.kind == "Reg" else "classifier")
+                kind=("regressor" if self.kind == "Reg" else "classifier"),
+                dtype="float32",
             )
         if self.framework.lower() == "tflow":
             return self._build_tflow_model()
-- 
GitLab