Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 0d78e2c9 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Quit using a scikit-learn deprecated attribute.

parent 18a459fe
No related branches found
No related tags found
1 merge request!63Finalize DecLearn v2.4.0.
......@@ -489,12 +489,13 @@ class SklearnSGDModel(Model):
) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
"""Return a function to compute point-wise loss for a given batch."""
# fmt: off
# Gather / instantiate a loss function from the wrapped model's specs.
if hasattr(self._model, "loss_function_"):
loss_smp = self._model.loss_function_.py_loss
else:
loss_cls, *args = self._model.loss_functions[self._model.loss]
loss_smp = loss_cls(*args).py_loss
# Instantiate a loss function from the wrapped model's specs.
loss_cls, *args = self._model.loss_functions[self._model.loss]
if self._model.loss in (
"huber", "epsilon_insensitive", "squared_epsilon_insensitive"
):
args = (self._model.epsilon,)
loss_smp = loss_cls(*args).py_loss
# Wrap it to support batched inputs.
def loss_1d(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
return np.array([loss_smp(*smp) for smp in zip(y_pred, y_true)])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment