Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit 811699ca authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Fix expanded-dim inputs handling in MAE, MSE and R2 metrics.

Until now, when provided with y_true:<shape>, y_pred:(<shape>, 1)
inputs, the MAE, MSE and R2 metrics would be entirely wrong, due
to the way numpy casts operations between such inputs. This patch
adds some shape-verification and squeezing operations that fix the
computations. Unit tests were added to cover this case.
parent cfb294ea
No related branches found
No related tags found
1 merge request!57Improve tests coverage and fix test-digged bugs
......@@ -23,6 +23,7 @@ from typing import ClassVar, Dict, Optional, Union
import numpy as np
from declearn.metrics._api import Metric
from declearn.metrics._utils import squeeze_into_identical_shapes
__all__ = [
"MeanMetric",
......@@ -129,6 +130,7 @@ class MeanAbsoluteError(MeanMetric):
y_pred: np.ndarray,
) -> np.ndarray:
# Sample-wise (sum of) absolute error function.
y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
errors = np.abs(y_true - y_pred)
while errors.ndim > 1:
errors = errors.sum(axis=-1)
......@@ -158,6 +160,7 @@ class MeanSquaredError(MeanMetric):
y_pred: np.ndarray,
) -> np.ndarray:
# Sample-wise (sum of) squared error function.
y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
errors = np.square(y_true - y_pred)
while errors.ndim > 1:
errors = errors.sum(axis=-1)
......
......@@ -22,6 +22,7 @@ from typing import ClassVar, Dict, Optional, Union
import numpy as np
from declearn.metrics._api import Metric
from declearn.metrics._utils import squeeze_into_identical_shapes
__all__ = [
"RSquared",
......@@ -113,6 +114,7 @@ class RSquared(Metric):
y_pred: np.ndarray,
s_wght: Optional[np.ndarray] = None,
) -> None:
y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
# Verify sample weights' shape, or set up 1-valued ones.
s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred))
# Update the residual sum of squares. wSSr = sum(w * (y - p)^2)
......
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend utils for metrics' computations."""
from typing import Tuple
import numpy as np
__all__ = [
"squeeze_into_identical_shapes",
]
def squeeze_into_identical_shapes(
y_true: np.ndarray,
y_pred: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""Verify that inputs have identical shapes, up to squeezable dims.
Return the input arrays, squeezed when needed.
Raise a ValueError if they cannot be made to match.
"""
# Case of identical-shape inputs.
if y_true.shape == y_pred.shape:
return y_true, y_pred
# Case of identical-shape inputs up to squeezable dims.
y_true = np.squeeze(y_true)
y_pred = np.squeeze(y_pred)
if y_true.shape == y_pred.shape:
# Handle edge case of scalar values: preserve one dimension.
if not y_true.shape:
y_true = np.expand_dims(y_true, 0)
y_pred = np.expand_dims(y_pred, 0)
return y_true, y_pred
# Case of mismatching shapes.
raise ValueError(
"Received inputs with incompatible shapes: "
f"y_true has shape {y_true.shape}, y_pred has shape {y_pred.shape}."
)
......@@ -24,7 +24,7 @@ import numpy as np
import pytest
from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric
from declearn.test_utils import make_importable
from declearn.test_utils import assert_dict_equal, make_importable
# relative imports from `metric_testing.py`
with make_importable(os.path.dirname(__file__)):
......@@ -69,7 +69,7 @@ class MeanMetricTestSuite(MetricTestSuite):
"""Unit tests suite for `MeanMetric` subclasses."""
def test_update_errors(self, test_case: MetricTestCase) -> None:
"""Test that `update` raises on improper `s_wght` shapes."""
"""Test that `update` raises on improper input shapes."""
metric = test_case.metric
inputs = test_case.inputs
# Test with multi-dimensional sample weights.
......@@ -80,12 +80,34 @@ class MeanMetricTestSuite(MetricTestSuite):
s_wght = np.ones(shape=(len(inputs["y_pred"]) + 2,))
with pytest.raises(ValueError):
metric.update(inputs["y_true"], inputs["y_pred"], s_wght)
# Test with mismatching-shape inputs.
y_true = inputs["y_true"]
y_pred = np.stack([inputs["y_pred"], inputs["y_pred"]], axis=-1)
with pytest.raises(ValueError):
metric.update(y_true, y_pred, s_wght)
def test_zero_result(self, test_case: MetricTestCase) -> None:
"""Test that `get_results` works with zero-valued divisor."""
metric = test_case.metric
assert metric.get_result() == {metric.name: 0.0}
def test_update_expanded_shape(self, test_case: MetricTestCase) -> None:
"""Test that the metric supports expanded-dim input predictions."""
# Gather states with basic inputs.
metric, inputs = test_case.metric, test_case.inputs
metric.update(**inputs)
states = metric.get_states()
metric.reset()
# Do the same with expanded-dim predictions.
metric.update(
inputs["y_true"],
np.expand_dims(inputs["y_pred"], -1),
inputs.get("s_wght"),
)
st_bis = metric.get_states()
# Verify that results are the same.
assert_dict_equal(states, st_bis)
@pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"])
@pytest.mark.parametrize("case", ["mae"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment