diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py index 160a6d3970295c4b2e6c6fd66653b8b8f945b07d..17263da5a57ced83ff95517307ddf0dfda23bfbb 100644 --- a/declearn/model/sklearn/_sgd.py +++ b/declearn/model/sklearn/_sgd.py @@ -489,12 +489,13 @@ class SklearnSGDModel(Model): ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: """Return a function to compute point-wise loss for a given batch.""" # fmt: off - # Gather / instantiate a loss function from the wrapped model's specs. - if hasattr(self._model, "loss_function_"): - loss_smp = self._model.loss_function_.py_loss - else: - loss_cls, *args = self._model.loss_functions[self._model.loss] - loss_smp = loss_cls(*args).py_loss + # Instantiate a loss function from the wrapped model's specs. + loss_cls, *args = self._model.loss_functions[self._model.loss] + if self._model.loss in ( + "huber", "epsilon_insensitive", "squared_epsilon_insensitive" + ): + args = (self._model.epsilon,) + loss_smp = loss_cls(*args).py_loss # Wrap it to support batched inputs. def loss_1d(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: return np.array([loss_smp(*smp) for smp in zip(y_pred, y_true)])