From 9ee99bc8c4db1292f5d80d72040df6fc07dd26fe Mon Sep 17 00:00:00 2001
From: BIGAUD Nathan <nathan.bigaud@inria.fr>
Date: Mon, 6 Feb 2023 15:40:43 +0000
Subject: [PATCH] Adding unit tests for vectors

---
 declearn/model/sklearn/_np_vec.py    |  13 +-
 declearn/model/tensorflow/_vector.py |  16 ++-
 declearn/model/torch/_vector.py      |  13 +-
 declearn/test_utils/_vectors.py      |  19 ++-
 test/model/test_vector.py            | 170 +++++++++++++++++++++++++++
 5 files changed, 212 insertions(+), 19 deletions(-)
 create mode 100644 test/model/test_vector.py

diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 8bd21e6a..75d1d42b 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -25,7 +25,6 @@ from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
 from declearn.model.api._vector import Vector, register_vector_type
 
-
 __all__ = [
     "NumpyVector",
 ]
@@ -69,10 +68,14 @@ class NumpyVector(Vector):
 
     def __eq__(self, other: Any) -> bool:
         valid = isinstance(other, NumpyVector)
-        valid = valid and (self.coefs.keys() == other.coefs.keys())
-        return valid and all(
-            np.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
-        )
+        if valid:
+            valid = self.coefs.keys() == other.coefs.keys()
+        if valid:
+            valid = all(
+                np.array_equal(self.coefs[k], other.coefs[k])
+                for k in self.coefs
+            )
+        return valid
 
     def sign(
         self,
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index c8c27a5a..e27cc572 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -25,11 +25,12 @@ import tensorflow as tf  # type: ignore
 from tensorflow.python.framework.ops import EagerTensor  # type: ignore
 # pylint: enable=no-name-in-module
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
-# fmt: on
 
 from declearn.model.api import Vector, register_vector_type
 from declearn.model.sklearn import NumpyVector
 
+# fmt: on
+
 
 @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
 class TensorflowVector(Vector):
@@ -131,11 +132,14 @@ class TensorflowVector(Vector):
         other: Any,
     ) -> bool:
         valid = isinstance(other, TensorflowVector)
-        valid = valid & (self.coefs.keys() == other.coefs.keys())
-        return valid and all(
-            self._tensor_equal(self.coefs[key], other.coefs[key])
-            for key in self.coefs
-        )
+        if valid:
+            valid = self.coefs.keys() == other.coefs.keys()
+        if valid:
+            valid = all(
+                self._tensor_equal(self.coefs[key], other.coefs[key])
+                for key in self.coefs
+            )
+        return valid
 
     @staticmethod
     def _tensor_equal(
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index 58ca0a3b..25470995 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -111,11 +111,14 @@ class TorchVector(Vector):
         other: Any,
     ) -> bool:
         valid = isinstance(other, TorchVector)
-        valid = valid and (self.coefs.keys() == other.coefs.keys())
-        return valid and all(
-            np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy())
-            for k in self.coefs
-        )
+        if valid:
+            valid = self.coefs.keys() == other.coefs.keys()
+        if valid:
+            valid = all(
+                np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy())
+                for k in self.coefs
+            )
+        return valid
 
     def sign(self) -> Self:  # type: ignore
         # false-positive; pylint: disable=no-member
diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py
index d888c962..d789cec4 100644
--- a/declearn/test_utils/_vectors.py
+++ b/declearn/test_utils/_vectors.py
@@ -28,7 +28,6 @@ from numpy.typing import ArrayLike
 from declearn.model.api import Vector
 from declearn.model.sklearn import NumpyVector
 
-
 __all__ = [
     "FrameworkType",
     "GradientsTestCase",
@@ -116,14 +115,28 @@ class GradientsTestCase:
         )
 
     @property
-    def mock_allzero_gradient(self) -> Vector:
+    def mock_ones(self) -> Vector:
         """Instantiate a Vector with random-valued mock gradients.
 
         Note: the RNG used to generate gradients has a fixed seed,
                 to that gradients have the same values whatever the
                 tensor framework used is.
         """
-        shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)]
+        shapes = [(5, 5), (4,), (1,)]
+        values = [np.ones(shape) for shape in shapes]
+        return self.vector_cls(
+            {str(idx): self.convert(value) for idx, value in enumerate(values)}
+        )
+
+    @property
+    def mock_zeros(self) -> Vector:
+        """Instantiate a Vector with random-valued mock gradients.
+
+        Note: the RNG used to generate gradients has a fixed seed,
+                to that gradients have the same values whatever the
+                tensor framework used is.
+        """
+        shapes = [(5, 5), (4,), (1,)]
         values = [np.zeros(shape) for shape in shapes]
         return self.vector_cls(
             {str(idx): self.convert(value) for idx, value in enumerate(values)}
diff --git a/test/model/test_vector.py b/test/model/test_vector.py
new file mode 100644
index 00000000..2b46f002
--- /dev/null
+++ b/test/model/test_vector.py
@@ -0,0 +1,170 @@
+# coding: utf-8
+
+"""Unit tests for Vector and its subclasses.
+
+This test makes use of `declearn.test_utils.list_available_frameworks`
+so as to modularly run a standard test suite on the available Vector
+subclasses.
+"""
+
+import json
+
+import numpy as np
+import pytest
+
+from declearn.test_utils import (
+    FrameworkType,
+    GradientsTestCase,
+    list_available_frameworks,
+)
+from declearn.utils import json_pack, json_unpack
+
+
+@pytest.fixture(name="framework", params=list_available_frameworks())
+def framework_fixture(request):
+    """Fixture to provide with the name of a model framework."""
+    return request.param
+
+
+class TestVectorAbstractMethods:
+    """Test abstract methods."""
+
+    def test_sum(self, framework: FrameworkType) -> None:
+        """Test coefficient-wise sum."""
+        grad = GradientsTestCase(framework)
+        ones = grad.mock_ones
+        test_coefs = ones.sum().coefs
+        test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+        values = [25.0, 4.0, 1.0]
+        assert values == test_values
+
+    def test_max(self, framework: FrameworkType) -> None:
+        """Test coef.-wise, element-wise maximum wrt to another Vector."""
+        grad = GradientsTestCase(framework)
+        ones, zeros = (grad.mock_ones, grad.mock_zeros)
+        values = [np.ones((5, 5)), np.ones((4,)), np.ones((1,))]
+        # test Vector
+        test_coefs = zeros.maximum(ones).coefs
+        test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+        assert all(
+            (values[i] == test_values[i]).all() for i in range(len(values))
+        )
+        # test float
+        test_coefs = zeros.maximum(1.0).coefs
+        test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+        assert all(
+            (values[i] == test_values[i]).all() for i in range(len(values))
+        )
+
+    def test_min(self, framework: FrameworkType) -> None:
+        """Test coef.-wise, element-wise minimum wrt to another Vector."""
+        grad = GradientsTestCase(framework)
+        ones, zeros = (grad.mock_ones, grad.mock_zeros)
+        values = [np.zeros((5, 5)), np.zeros((4,)), np.zeros((1,))]
+        # test Vector
+        test_coefs = ones.minimum(zeros).coefs
+        test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+        assert all(
+            (values[i] == test_values[i]).all() for i in range(len(values))
+        )
+        # test float
+        test_coefs = zeros.minimum(1.0).coefs
+        test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+        assert all(
+            (values[i] == test_values[i]).all() for i in range(len(values))
+        )
+
+    def test_sign(self, framework: FrameworkType) -> None:
+        """Test coefficient-wise sign check"""
+        grad = GradientsTestCase(framework)
+        ones = grad.mock_ones
+        for vec in ones, -1 * ones:
+            test_coefs = vec.sign().coefs
+            test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+            values = [grad.to_numpy(vec.coefs[el]) for el in vec.coefs]
+            assert all(
+                (values[i] == test_values[i]).all() for i in range(len(values))
+            )
+
+    def test_eq(self, framework: FrameworkType) -> None:
+        """Test __eq__ operator"""
+        grad = GradientsTestCase(framework)
+        ones, ones_bis, zeros = grad.mock_ones, grad.mock_ones, grad.mock_zeros
+        rand = grad.mock_gradient
+        assert ones == ones_bis
+        assert zeros != ones
+        assert ones != rand
+        assert 1.0 != ones
+
+
+class TestVector:
+    """Test non-abstract methods"""
+
+    def test_operator(self, framework: FrameworkType) -> None:
+        "Test all element-wise operators wiring"
+        grad = GradientsTestCase(framework)
+
+        def _get_sq_root_two(ones, zeros):
+            """Returns the comaprison of a hardcoded sequence of operations
+            with its exptected result"""
+            values = [
+                el * (2 ** (1 / 2))
+                for el in [np.ones((5, 5)), np.ones((4,)), np.ones((1,))]
+            ]
+            test_grad = (0 + (1.0 * ones + ones * 1.0) * ones / 1.0 - 0) ** (
+                (zeros + ones - zeros) / (ones + ones)
+            )
+            test_coefs = test_grad.coefs
+            test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
+            return all(
+                (values[i] == test_values[i]).all() for i in range(len(values))
+            )
+
+        ones, zeros = grad.mock_ones, grad.mock_zeros
+        assert _get_sq_root_two(ones, zeros)
+        assert _get_sq_root_two(ones, 0)
+
+    def test_pack(self, framework: FrameworkType) -> None:
+        """Test that `Vector.pack` returns JSON-serializable results."""
+        grad = GradientsTestCase(framework)
+        ones = grad.mock_ones
+        packed = ones.pack()
+        # Check that the output is a dict with str keys.
+        assert isinstance(packed, dict)
+        assert all(isinstance(key, str) for key in packed)
+        # Check that the "packed" dict is JSON-serializable.
+        dump = json.dumps(packed, default=json_pack)
+        load = json.loads(dump, object_hook=json_unpack)
+        assert isinstance(load, dict)
+        assert load.keys() == packed.keys()
+        assert all(np.all(load[key] == packed[key]) for key in load)
+
+    def test_unpack(self, framework: FrameworkType) -> None:
+        """Test that `Vector.unpack` counterparts `Vector.pack` adequately."""
+        grad = GradientsTestCase(framework)
+        ones = grad.mock_ones
+        packed = ones.pack()
+        test_vec = grad.vector_cls.unpack(packed)
+        assert test_vec == ones
+
+    def test_repr(self, framework: FrameworkType) -> None:
+        """Test shape and dtypes together using __repr__"""
+        grad = GradientsTestCase(framework)
+        test_value = repr(grad.mock_ones)
+        value = grad.mock_ones.coefs["0"]
+        arr_type = f"{type(value).__module__}.{type(value).__name__}"
+        value = (
+            f"{grad.vector_cls.__name__} with 3 coefs:"
+            f"\n    0: float64 {arr_type} with shape (5, 5)"
+            f"\n    1: float64 {arr_type} with shape (4,)"
+            f"\n    2: float64 {arr_type} with shape (1,)"
+        )
+        assert test_value == value
+
+    def test_json_serialization(self, framework: FrameworkType) -> None:
+        """Test that a Vector instance is JSON-serializable."""
+        vector = GradientsTestCase(framework).mock_gradient
+        dump = json.dumps(vector, default=json_pack)
+        loaded = json.loads(dump, object_hook=json_unpack)
+        assert isinstance(loaded, type(vector))
+        assert loaded == vector
-- 
GitLab