diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 2c4189c675971944dab74ab58aa393140e5e3da0..aca63e392f91d19f7e2e863a0b121d8689a2dc55 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -23,6 +23,7 @@ from typing import ClassVar, Dict, Optional, Union import numpy as np from declearn.metrics._api import Metric +from declearn.metrics._utils import squeeze_into_identical_shapes __all__ = [ "MeanMetric", @@ -129,6 +130,7 @@ class MeanAbsoluteError(MeanMetric): y_pred: np.ndarray, ) -> np.ndarray: # Sample-wise (sum of) absolute error function. + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) errors = np.abs(y_true - y_pred) while errors.ndim > 1: errors = errors.sum(axis=-1) @@ -158,6 +160,7 @@ class MeanSquaredError(MeanMetric): y_pred: np.ndarray, ) -> np.ndarray: # Sample-wise (sum of) squared error function. + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) errors = np.square(y_true - y_pred) while errors.ndim > 1: errors = errors.sum(axis=-1) diff --git a/declearn/metrics/_rsquared.py b/declearn/metrics/_rsquared.py index d61fdfeaf34fc7c6cb3a47433861e42b8ec43e5e..f00b40b2267a523d054ae6229cb8ae3a36c152c8 100644 --- a/declearn/metrics/_rsquared.py +++ b/declearn/metrics/_rsquared.py @@ -22,6 +22,7 @@ from typing import ClassVar, Dict, Optional, Union import numpy as np from declearn.metrics._api import Metric +from declearn.metrics._utils import squeeze_into_identical_shapes __all__ = [ "RSquared", @@ -113,6 +114,7 @@ class RSquared(Metric): y_pred: np.ndarray, s_wght: Optional[np.ndarray] = None, ) -> None: + y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred) # Verify sample weights' shape, or set up 1-valued ones. s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred)) # Update the residual sum of squares. wSSr = sum(w * (y - p)^2) diff --git a/declearn/metrics/_utils.py b/declearn/metrics/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..757637a6d560883f18ab5fc931db4615deb17c47 --- /dev/null +++ b/declearn/metrics/_utils.py @@ -0,0 +1,54 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backend utils for metrics' computations.""" + +from typing import Tuple + +import numpy as np + +__all__ = [ + "squeeze_into_identical_shapes", +] + + +def squeeze_into_identical_shapes( + y_true: np.ndarray, + y_pred: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """Verify that inputs have identical shapes, up to squeezable dims. + + Return the input arrays, squeezed when needed. + Raise a ValueError if they cannot be made to match. + """ + # Case of identical-shape inputs. + if y_true.shape == y_pred.shape: + return y_true, y_pred + # Case of identical-shape inputs up to squeezable dims. + y_true = np.squeeze(y_true) + y_pred = np.squeeze(y_pred) + if y_true.shape == y_pred.shape: + # Handle edge case of scalar values: preserve one dimension. + if not y_true.shape: + y_true = np.expand_dims(y_true, 0) + y_pred = np.expand_dims(y_pred, 0) + return y_true, y_pred + # Case of mismatching shapes. + raise ValueError( + "Received inputs with incompatible shapes: " + f"y_true has shape {y_true.shape}, y_pred has shape {y_pred.shape}." + ) diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py index cd55e880d6561de0d5b33636be8cb57e4f21fdb0..de472efdf1fa668752e61e26b14c6ff9b187a29c 100644 --- a/test/metrics/test_mae_mse.py +++ b/test/metrics/test_mae_mse.py @@ -24,7 +24,7 @@ import numpy as np import pytest from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric -from declearn.test_utils import make_importable +from declearn.test_utils import assert_dict_equal, make_importable # relative imports from `metric_testing.py` with make_importable(os.path.dirname(__file__)): @@ -69,7 +69,7 @@ class MeanMetricTestSuite(MetricTestSuite): """Unit tests suite for `MeanMetric` subclasses.""" def test_update_errors(self, test_case: MetricTestCase) -> None: - """Test that `update` raises on improper `s_wght` shapes.""" + """Test that `update` raises on improper input shapes.""" metric = test_case.metric inputs = test_case.inputs # Test with multi-dimensional sample weights. @@ -80,12 +80,34 @@ class MeanMetricTestSuite(MetricTestSuite): s_wght = np.ones(shape=(len(inputs["y_pred"]) + 2,)) with pytest.raises(ValueError): metric.update(inputs["y_true"], inputs["y_pred"], s_wght) + # Test with mismatching-shape inputs. + y_true = inputs["y_true"] + y_pred = np.stack([inputs["y_pred"], inputs["y_pred"]], axis=-1) + with pytest.raises(ValueError): + metric.update(y_true, y_pred, s_wght) def test_zero_result(self, test_case: MetricTestCase) -> None: """Test that `get_results` works with zero-valued divisor.""" metric = test_case.metric assert metric.get_result() == {metric.name: 0.0} + def test_update_expanded_shape(self, test_case: MetricTestCase) -> None: + """Test that the metric supports expanded-dim input predictions.""" + # Gather states with basic inputs. + metric, inputs = test_case.metric, test_case.inputs + metric.update(**inputs) + states = metric.get_states() + metric.reset() + # Do the same with expanded-dim predictions. + metric.update( + inputs["y_true"], + np.expand_dims(inputs["y_pred"], -1), + inputs.get("s_wght"), + ) + st_bis = metric.get_states() + # Verify that results are the same. + assert_dict_equal(states, st_bis) + @pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"]) @pytest.mark.parametrize("case", ["mae"])