From ff655eca8628765fd296636329705ece47bce41f Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 4 Aug 2023 10:23:51 +0200 Subject: [PATCH] Patch 'SklearnSGDModel' for scikit-learn 1.3 compatibility. --- declearn/model/sklearn/_sgd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index c53cab85..f3dd3d7a 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -270,8 +270,8 @@ class SklearnSGDModel(Model): raise TypeError( f"Missing required '{key}' in the received vector." ) - self._model.coef_ = weights.coefs["coef"].copy() - self._model.intercept_ = weights.coefs["intercept"].copy() + self._model.coef_ = weights.coefs["coef"].astype("float64") + self._model.intercept_ = weights.coefs["intercept"].astype("float64") def compute_batch_gradients( self, @@ -313,6 +313,7 @@ class SklearnSGDModel(Model): "'SklearnSGDModel' requires (array, array, [array|None]) " "data batches." ) + x_data = x_data.astype("float64", copy=False) # type: ignore return x_data, y_data, s_wght # type: ignore def _compute_sample_gradient( -- GitLab