Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 38accb64 authored by BIGAUD Nathan, committed by ANDREY Paul
Browse files

Refactor 'SklearnSGDModel' dtype choice code into a private function.


Co-authored-by: Paul ANDREY <paul.andrey@inria.fr>
parent c4fe82b5
No related branches found
No related tags found
1 merge request: !50 "Fix compatibility issues with scikit-learn 1.3.0 and tensorflow 2.13.0"
Pipeline #831139 passed
......@@ -25,9 +25,9 @@ from typing import (
)
import numpy as np
from numpy.typing import ArrayLike
import pandas as pd
import sklearn # type: ignore
from numpy.typing import ArrayLike
from scipy.sparse import spmatrix # type: ignore
from sklearn.linear_model import SGDClassifier, SGDRegressor # type: ignore
from typing_extensions import Self # future: import from typing (py >=3.11)
......@@ -65,6 +65,51 @@ REG_LOSSES = (
# Type alias: union of the array-like containers listed below
# (numpy array, scipy sparse matrix, pandas DataFrame or Series).
DataArray = Union[np.ndarray, spmatrix, pd.DataFrame, pd.Series]
def select_sgd_model_dtype(
    model: Union[SGDClassifier, SGDRegressor],
    dtype: Union[str, np.dtype, Type[np.number]],
) -> str:
    """Return the current or future dtype of a scikit-learn SGD model.

    Emit a `RuntimeWarning` if the model has already been initialized
    with a dtype that does not match the input one, and/or if the
    scikit-learn version makes it impossible to enforce a non-default
    dtype.

    Parameters
    ----------
    model:
        Scikit-learn SGD model. If it exposes a `coef_` attribute, the
        dtype of that array takes precedence over the `dtype` argument.
    dtype:
        User-requested dtype, as a name, a numpy dtype or a numpy
        number type.

    Returns
    -------
    dtype:
        Name of the selected dtype. This is:
        - the name of `model.coef_.dtype` if it exists and differs
          from the requested dtype;
        - "float64" if scikit-learn is older than 1.3 and another
          dtype was requested;
        - the name of the requested dtype otherwise.

    Raises
    ------
    TypeError
        If `dtype` is neither a str, a np.dtype nor a np.number type
        (raised by `_get_dtype_name`).
    """
    # Identify the user-defined dtype's name.
    dtype = _get_dtype_name(dtype)
    # Warn if the dtype is already defined and does not match the argument.
    if hasattr(model, "coef_") and (model.coef_.dtype.name != dtype):
        warnings.warn(
            f"Cannot enforce dtype '{dtype}' for pre-initialized "
            f"scikit-learn SGD model (dtype '{model.coef_.dtype.name}').",
            RuntimeWarning,
        )
        return model.coef_.dtype.name
    # When using scikit-learn < 1.3, warn about un-settable dtype.
    # Compare (major, minor) as a tuple, so that a future major version
    # (e.g. 2.0) is not mistaken for a pre-1.3 release.
    major, minor = sklearn.__version__.split(".", 2)[:2]
    if (int(major), int(minor)) < (1, 3) and (dtype != "float64"):
        warnings.warn(
            "Using scikit-learn <1.3; hence the 'float64' dtype will "
            f"forcibly be used rather than user-input '{dtype}'.",
            RuntimeWarning,
        )
        return "float64"
    return dtype
def _get_dtype_name(
dtype: Union[str, np.dtype, Type[np.number]],
) -> str:
"""Gather the name of an input numpy dtype."""
if isinstance(dtype, np.dtype):
return dtype.name
if isinstance(dtype, type) and issubclass(dtype, np.number):
return dtype.__name__
if isinstance(dtype, str):
return dtype
raise TypeError("'dtype' should be a str, np.dtype or np.number type.")
@register_type(name="SklearnSGDModel", group="Model")
class SklearnSGDModel(Model):
"""Model wrapper for Scikit-Learn SGDClassifier and SGDRegressor.
......@@ -120,26 +165,7 @@ class SklearnSGDModel(Model):
average=False,
)
super().__init__(model)
if hasattr(model, "coef_"):
self._dtype = self._model.coef_.dtype.name
elif isinstance(dtype, np.dtype):
self._dtype = dtype.name
elif isinstance(dtype, type) and issubclass(dtype, np.number):
self._dtype = dtype.__name__
elif isinstance(dtype, str):
self._dtype = dtype
else:
raise TypeError(
"'dtype' should be a str, np.dtype or np.number type."
)
if int(sklearn.__version__.split(".", 2)[1]) < 3:
if self._dtype != "float64":
warnings.warn(
"Using scikit-learn <1.3; hence the float64 dtype will"
f"forcibly be used rather than user-input {self._dtype}.",
RuntimeWarning,
)
self._dtype = "float64"
self._dtype = select_sgd_model_dtype(model, dtype)
self._predict = (
self._model.decision_function
if isinstance(model, SGDClassifier)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment