From ff655eca8628765fd296636329705ece47bce41f Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 4 Aug 2023 10:23:51 +0200
Subject: [PATCH] Patch 'SklearnSGDModel' for scikit-learn 1.3 compatibility.

---
 declearn/model/sklearn/_sgd.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index c53cab85..f3dd3d7a 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -270,8 +270,8 @@ class SklearnSGDModel(Model):
                 raise TypeError(
                     f"Missing required '{key}' in the received vector."
                 )
-        self._model.coef_ = weights.coefs["coef"].copy()
-        self._model.intercept_ = weights.coefs["intercept"].copy()
+        self._model.coef_ = weights.coefs["coef"].astype("float64")
+        self._model.intercept_ = weights.coefs["intercept"].astype("float64")
 
     def compute_batch_gradients(
         self,
@@ -313,6 +313,7 @@ class SklearnSGDModel(Model):
                 "'SklearnSGDModel' requires (array, array, [array|None]) "
                 "data batches."
             )
+        x_data = x_data.astype("float64", copy=False)  # type: ignore
         return x_data, y_data, s_wght  # type: ignore
 
     def _compute_sample_gradient(
-- 
GitLab