diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py
index 2c4189c675971944dab74ab58aa393140e5e3da0..aca63e392f91d19f7e2e863a0b121d8689a2dc55 100644
--- a/declearn/metrics/_mean.py
+++ b/declearn/metrics/_mean.py
@@ -23,6 +23,7 @@ from typing import ClassVar, Dict, Optional, Union
 import numpy as np
 
 from declearn.metrics._api import Metric
+from declearn.metrics._utils import squeeze_into_identical_shapes
 
 __all__ = [
     "MeanMetric",
@@ -129,6 +130,7 @@ class MeanAbsoluteError(MeanMetric):
         y_pred: np.ndarray,
     ) -> np.ndarray:
         # Sample-wise (sum of) absolute error function.
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         errors = np.abs(y_true - y_pred)
         while errors.ndim > 1:
             errors = errors.sum(axis=-1)
@@ -158,6 +160,7 @@ class MeanSquaredError(MeanMetric):
         y_pred: np.ndarray,
     ) -> np.ndarray:
         # Sample-wise (sum of) squared error function.
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         errors = np.square(y_true - y_pred)
         while errors.ndim > 1:
             errors = errors.sum(axis=-1)
diff --git a/declearn/metrics/_rsquared.py b/declearn/metrics/_rsquared.py
index d61fdfeaf34fc7c6cb3a47433861e42b8ec43e5e..f00b40b2267a523d054ae6229cb8ae3a36c152c8 100644
--- a/declearn/metrics/_rsquared.py
+++ b/declearn/metrics/_rsquared.py
@@ -22,6 +22,7 @@ from typing import ClassVar, Dict, Optional, Union
 import numpy as np
 
 from declearn.metrics._api import Metric
+from declearn.metrics._utils import squeeze_into_identical_shapes
 
 __all__ = [
     "RSquared",
@@ -113,6 +114,7 @@ class RSquared(Metric):
         y_pred: np.ndarray,
         s_wght: Optional[np.ndarray] = None,
     ) -> None:
+        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
         # Verify sample weights' shape, or set up 1-valued ones.
         s_wght = self._prepare_sample_weights(s_wght, n_samples=len(y_pred))
         # Update the residual sum of squares. wSSr = sum(w * (y - p)^2)
diff --git a/declearn/metrics/_utils.py b/declearn/metrics/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..757637a6d560883f18ab5fc931db4615deb17c47
--- /dev/null
+++ b/declearn/metrics/_utils.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Backend utils for metrics' computations."""
+
+from typing import Tuple
+
+import numpy as np
+
+__all__ = [
+    "squeeze_into_identical_shapes",
+]
+
+
+def squeeze_into_identical_shapes(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Verify that inputs have identical shapes, up to squeezable dims.
+
+    Return the input arrays, squeezed when needed.
+    Raise a ValueError if they cannot be made to match.
+    """
+    # Case of identical-shape inputs.
+    if y_true.shape == y_pred.shape:
+        return y_true, y_pred
+    # Case of identical-shape inputs up to squeezable dims.
+    y_true = np.squeeze(y_true)
+    y_pred = np.squeeze(y_pred)
+    if y_true.shape == y_pred.shape:
+        # Handle edge case of scalar values: preserve one dimension.
+        if not y_true.shape:
+            y_true = np.expand_dims(y_true, 0)
+            y_pred = np.expand_dims(y_pred, 0)
+        return y_true, y_pred
+    # Case of mismatching shapes.
+    raise ValueError(
+        "Received inputs with incompatible shapes: "
+        f"y_true has shape {y_true.shape}, y_pred has shape {y_pred.shape}."
+    )
diff --git a/test/metrics/test_mae_mse.py b/test/metrics/test_mae_mse.py
index cd55e880d6561de0d5b33636be8cb57e4f21fdb0..de472efdf1fa668752e61e26b14c6ff9b187a29c 100644
--- a/test/metrics/test_mae_mse.py
+++ b/test/metrics/test_mae_mse.py
@@ -24,7 +24,7 @@ import numpy as np
 import pytest
 
 from declearn.metrics import MeanAbsoluteError, MeanSquaredError, Metric
-from declearn.test_utils import make_importable
+from declearn.test_utils import assert_dict_equal, make_importable
 
 # relative imports from `metric_testing.py`
 with make_importable(os.path.dirname(__file__)):
@@ -69,7 +69,7 @@ class MeanMetricTestSuite(MetricTestSuite):
     """Unit tests suite for `MeanMetric` subclasses."""
 
     def test_update_errors(self, test_case: MetricTestCase) -> None:
-        """Test that `update` raises on improper `s_wght` shapes."""
+        """Test that `update` raises on improper input shapes."""
         metric = test_case.metric
         inputs = test_case.inputs
         # Test with multi-dimensional sample weights.
@@ -80,12 +80,34 @@ class MeanMetricTestSuite(MetricTestSuite):
         s_wght = np.ones(shape=(len(inputs["y_pred"]) + 2,))
         with pytest.raises(ValueError):
             metric.update(inputs["y_true"], inputs["y_pred"], s_wght)
+        # Test with mismatching-shape inputs.
+        y_true = inputs["y_true"]
+        y_pred = np.stack([inputs["y_pred"], inputs["y_pred"]], axis=-1)
+        with pytest.raises(ValueError):
+            metric.update(y_true, y_pred, s_wght)
 
     def test_zero_result(self, test_case: MetricTestCase) -> None:
         """Test that `get_results` works with zero-valued divisor."""
         metric = test_case.metric
         assert metric.get_result() == {metric.name: 0.0}
 
+    def test_update_expanded_shape(self, test_case: MetricTestCase) -> None:
+        """Test that the metric supports expanded-dim input predictions."""
+        # Gather states with basic inputs.
+        metric, inputs = test_case.metric, test_case.inputs
+        metric.update(**inputs)
+        states = metric.get_states()
+        metric.reset()
+        # Do the same with expanded-dim predictions.
+        metric.update(
+            inputs["y_true"],
+            np.expand_dims(inputs["y_pred"], -1),
+            inputs.get("s_wght"),
+        )
+        st_bis = metric.get_states()
+        # Verify that results are the same.
+        assert_dict_equal(states, st_bis)
+
 
 @pytest.mark.parametrize("weighted", [False, True], ids=["base", "weighted"])
 @pytest.mark.parametrize("case", ["mae"])