def select_sgd_model_dtype(
    model: "Union[SGDClassifier, SGDRegressor]",
    dtype: Union[str, np.dtype, Type[np.number]],
) -> str:
    """Return the current or future dtype of a scikit-learn SGD model.

    Emit a `RuntimeWarning` if the model has already been initialized
    with a dtype that does not match the input one, and/or if the
    scikit-learn version makes it impossible to enforce a non-default
    dtype.

    Parameters
    ----------
    model:
        Scikit-learn SGD model; it counts as pre-initialized when it
        exposes a `coef_` attribute.
    dtype:
        Requested dtype, as a string, a `numpy.dtype` instance or a
        `numpy.number` type.

    Returns
    -------
    dtype_name:
        Canonical numpy name of the dtype that will effectively be used:
        that of `model.coef_` when it exists and differs from the
        request, "float64" when scikit-learn is older than 1.3, and the
        (normalized) requested dtype otherwise.

    Raises
    ------
    TypeError
        If `dtype` is of unsupported type, or is not a valid numpy dtype.
    """
    # Identify the user-defined dtype's canonical name.
    dtype = _get_dtype_name(dtype)
    # Warn if the dtype is already defined and does not match the argument.
    if hasattr(model, "coef_") and (model.coef_.dtype.name != dtype):
        warnings.warn(
            f"Cannot enforce dtype '{dtype}' for pre-initialized "
            f"scikit-learn SGD model (dtype '{model.coef_.dtype.name}').",
            RuntimeWarning,
        )
        return model.coef_.dtype.name
    # When using scikit-learn < 1.3, warn about un-settable dtype.
    # Compare (major, minor) as a whole: checking the minor number alone
    # misclassifies 0.x releases and any future 2.x release.
    major, minor = (int(num) for num in sklearn.__version__.split(".")[:2])
    if (major, minor) < (1, 3) and (dtype != "float64"):
        warnings.warn(
            "Using scikit-learn <1.3; hence the 'float64' dtype will "
            f"forcibly be used rather than user-input '{dtype}'.",
            RuntimeWarning,
        )
        return "float64"
    return dtype


def _get_dtype_name(
    dtype: Union[str, np.dtype, Type[np.number]],
) -> str:
    """Gather the canonical name of an input numpy dtype.

    String and type inputs are normalized through `numpy.dtype`, so that
    aliases (e.g. "f8", "double", `numpy.single`) map to the same name as
    the one exposed by arrays' `dtype.name` attribute.

    Raises
    ------
    TypeError
        If `dtype` is neither a str, a `numpy.dtype` nor a `numpy.number`
        type, or if numpy cannot interpret it as a dtype.
    """
    if isinstance(dtype, np.dtype):
        return dtype.name
    if isinstance(dtype, type) and issubclass(dtype, np.number):
        return np.dtype(dtype).name
    if isinstance(dtype, str):
        return np.dtype(dtype).name
    raise TypeError("'dtype' should be a str, np.dtype or np.number type.")