From 0d78e2c91fab00eabca798fbfa84e977d336ed76 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 16 Feb 2024 17:37:03 +0100 Subject: [PATCH] Quit using a deprecated scikit-learn attribute. --- declearn/model/sklearn/_sgd.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 160a6d39..17263da5 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -489,12 +489,13 @@ class SklearnSGDModel(Model): ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: """Return a function to compute point-wise loss for a given batch.""" # fmt: off - # Gather / instantiate a loss function from the wrapped model's specs. - if hasattr(self._model, "loss_function_"): - loss_smp = self._model.loss_function_.py_loss - else: - loss_cls, *args = self._model.loss_functions[self._model.loss] - loss_smp = loss_cls(*args).py_loss + # Instantiate a loss function from the wrapped model's specs. + loss_cls, *args = self._model.loss_functions[self._model.loss] + if self._model.loss in ( + "huber", "epsilon_insensitive", "squared_epsilon_insensitive" + ): + args = (self._model.epsilon,) + loss_smp = loss_cls(*args).py_loss # Wrap it to support batched inputs. def loss_1d(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: return np.array([loss_smp(*smp) for smp in zip(y_pred, y_true)]) -- GitLab