diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index c53cab85144979caabac182426cdc5aa94271d3e..f3dd3d7ac6a2d822e421ed592acb8db0307e964c 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -270,8 +270,8 @@ class SklearnSGDModel(Model): raise TypeError( f"Missing required '{key}' in the received vector." ) - self._model.coef_ = weights.coefs["coef"].copy() - self._model.intercept_ = weights.coefs["intercept"].copy() + self._model.coef_ = weights.coefs["coef"].astype("float64") + self._model.intercept_ = weights.coefs["intercept"].astype("float64") def compute_batch_gradients( self, @@ -313,6 +313,7 @@ class SklearnSGDModel(Model): "'SklearnSGDModel' requires (array, array, [array|None]) " "data batches." ) + x_data = x_data.astype("float64", copy=False) # type: ignore return x_data, y_data, s_wght # type: ignore def _compute_sample_gradient(