diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 160a6d3970295c4b2e6c6fd66653b8b8f945b07d..17263da5a57ced83ff95517307ddf0dfda23bfbb 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -489,12 +489,13 @@ class SklearnSGDModel(Model):
     ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
         """Return a function to compute point-wise loss for a given batch."""
         # fmt: off
-        # Gather / instantiate a loss function from the wrapped model's specs.
-        if hasattr(self._model, "loss_function_"):
-            loss_smp = self._model.loss_function_.py_loss
-        else:
-            loss_cls, *args = self._model.loss_functions[self._model.loss]
-            loss_smp = loss_cls(*args).py_loss
+        # Instantiate a loss function from the wrapped model's specs.
+        loss_cls, *args = self._model.loss_functions[self._model.loss]
+        if self._model.loss in (
+            "huber", "epsilon_insensitive", "squared_epsilon_insensitive"
+        ):
+            args = (self._model.epsilon,)
+        loss_smp = loss_cls(*args).py_loss
         # Wrap it to support batched inputs.
         def loss_1d(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
             return np.array([loss_smp(*smp) for smp in zip(y_pred, y_true)])