Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit ff655eca authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Patch 'SklearnSGDModel' for scikit-learn 1.3 compatibility.

parent 8f450036
No related branches found
No related tags found
No related merge requests found
......@@ -270,8 +270,8 @@ class SklearnSGDModel(Model):
raise TypeError(
f"Missing required '{key}' in the received vector."
)
self._model.coef_ = weights.coefs["coef"].copy()
self._model.intercept_ = weights.coefs["intercept"].copy()
self._model.coef_ = weights.coefs["coef"].astype("float64")
self._model.intercept_ = weights.coefs["intercept"].astype("float64")
def compute_batch_gradients(
self,
......@@ -313,6 +313,7 @@ class SklearnSGDModel(Model):
"'SklearnSGDModel' requires (array, array, [array|None]) "
"data batches."
)
x_data = x_data.astype("float64", copy=False) # type: ignore
return x_data, y_data, s_wght # type: ignore
def _compute_sample_gradient(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment