Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 57419237 authored by BIGAUD Nathan's avatar BIGAUD Nathan
Browse files

Merge branch 'test-vector' into 'release-todo'

Adding unit tests for vectors

See merge request !22
parents 92e82dbf 9ee99bc8
No related branches found
No related tags found
2 merge requests!23Release version 2.0,!22Adding unit tests for vectors
Pipeline #751210 waiting for manual action
...@@ -25,7 +25,6 @@ from typing_extensions import Self # future: import from typing (Py>=3.11) ...@@ -25,7 +25,6 @@ from typing_extensions import Self # future: import from typing (Py>=3.11)
from declearn.model.api._vector import Vector, register_vector_type from declearn.model.api._vector import Vector, register_vector_type
__all__ = [ __all__ = [
"NumpyVector", "NumpyVector",
] ]
...@@ -69,10 +68,14 @@ class NumpyVector(Vector): ...@@ -69,10 +68,14 @@ class NumpyVector(Vector):
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
valid = isinstance(other, NumpyVector) valid = isinstance(other, NumpyVector)
valid = valid and (self.coefs.keys() == other.coefs.keys()) if valid:
return valid and all( valid = self.coefs.keys() == other.coefs.keys()
np.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs if valid:
) valid = all(
np.array_equal(self.coefs[k], other.coefs[k])
for k in self.coefs
)
return valid
def sign( def sign(
self, self,
......
...@@ -25,11 +25,12 @@ import tensorflow as tf # type: ignore ...@@ -25,11 +25,12 @@ import tensorflow as tf # type: ignore
from tensorflow.python.framework.ops import EagerTensor # type: ignore from tensorflow.python.framework.ops import EagerTensor # type: ignore
# pylint: enable=no-name-in-module # pylint: enable=no-name-in-module
from typing_extensions import Self # future: import from typing (Py>=3.11) from typing_extensions import Self # future: import from typing (Py>=3.11)
# fmt: on
from declearn.model.api import Vector, register_vector_type from declearn.model.api import Vector, register_vector_type
from declearn.model.sklearn import NumpyVector from declearn.model.sklearn import NumpyVector
# fmt: on
@register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
class TensorflowVector(Vector): class TensorflowVector(Vector):
...@@ -131,11 +132,14 @@ class TensorflowVector(Vector): ...@@ -131,11 +132,14 @@ class TensorflowVector(Vector):
other: Any, other: Any,
) -> bool: ) -> bool:
valid = isinstance(other, TensorflowVector) valid = isinstance(other, TensorflowVector)
valid = valid & (self.coefs.keys() == other.coefs.keys()) if valid:
return valid and all( valid = self.coefs.keys() == other.coefs.keys()
self._tensor_equal(self.coefs[key], other.coefs[key]) if valid:
for key in self.coefs valid = all(
) self._tensor_equal(self.coefs[key], other.coefs[key])
for key in self.coefs
)
return valid
@staticmethod @staticmethod
def _tensor_equal( def _tensor_equal(
......
...@@ -111,11 +111,14 @@ class TorchVector(Vector): ...@@ -111,11 +111,14 @@ class TorchVector(Vector):
other: Any, other: Any,
) -> bool: ) -> bool:
valid = isinstance(other, TorchVector) valid = isinstance(other, TorchVector)
valid = valid and (self.coefs.keys() == other.coefs.keys()) if valid:
return valid and all( valid = self.coefs.keys() == other.coefs.keys()
np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy()) if valid:
for k in self.coefs valid = all(
) np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy())
for k in self.coefs
)
return valid
def sign(self) -> Self: # type: ignore def sign(self) -> Self: # type: ignore
# false-positive; pylint: disable=no-member # false-positive; pylint: disable=no-member
......
...@@ -28,7 +28,6 @@ from numpy.typing import ArrayLike ...@@ -28,7 +28,6 @@ from numpy.typing import ArrayLike
from declearn.model.api import Vector from declearn.model.api import Vector
from declearn.model.sklearn import NumpyVector from declearn.model.sklearn import NumpyVector
__all__ = [ __all__ = [
"FrameworkType", "FrameworkType",
"GradientsTestCase", "GradientsTestCase",
...@@ -116,14 +115,28 @@ class GradientsTestCase: ...@@ -116,14 +115,28 @@ class GradientsTestCase:
) )
@property @property
def mock_allzero_gradient(self) -> Vector: def mock_ones(self) -> Vector:
"""Instantiate a Vector with random-valued mock gradients. """Instantiate a Vector with random-valued mock gradients.
Note: the RNG used to generate gradients has a fixed seed, Note: the RNG used to generate gradients has a fixed seed,
to that gradients have the same values whatever the to that gradients have the same values whatever the
tensor framework used is. tensor framework used is.
""" """
shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)] shapes = [(5, 5), (4,), (1,)]
values = [np.ones(shape) for shape in shapes]
return self.vector_cls(
{str(idx): self.convert(value) for idx, value in enumerate(values)}
)
@property
def mock_zeros(self) -> Vector:
"""Instantiate a Vector with random-valued mock gradients.
Note: the RNG used to generate gradients has a fixed seed,
to that gradients have the same values whatever the
tensor framework used is.
"""
shapes = [(5, 5), (4,), (1,)]
values = [np.zeros(shape) for shape in shapes] values = [np.zeros(shape) for shape in shapes]
return self.vector_cls( return self.vector_cls(
{str(idx): self.convert(value) for idx, value in enumerate(values)} {str(idx): self.convert(value) for idx, value in enumerate(values)}
......
# coding: utf-8
"""Unit tests for Vector and its subclasses.
This test makes use of `declearn.test_utils.list_available_frameworks`
so as to modularly run a standard test suite on the available Vector
subclasses.
"""
import json
import numpy as np
import pytest
from declearn.test_utils import (
FrameworkType,
GradientsTestCase,
list_available_frameworks,
)
from declearn.utils import json_pack, json_unpack
@pytest.fixture(name="framework", params=list_available_frameworks())
def framework_fixture(request):
"""Fixture to provide with the name of a model framework."""
return request.param
class TestVectorAbstractMethods:
"""Test abstract methods."""
def test_sum(self, framework: FrameworkType) -> None:
"""Test coefficient-wise sum."""
grad = GradientsTestCase(framework)
ones = grad.mock_ones
test_coefs = ones.sum().coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
values = [25.0, 4.0, 1.0]
assert values == test_values
def test_max(self, framework: FrameworkType) -> None:
"""Test coef.-wise, element-wise maximum wrt to another Vector."""
grad = GradientsTestCase(framework)
ones, zeros = (grad.mock_ones, grad.mock_zeros)
values = [np.ones((5, 5)), np.ones((4,)), np.ones((1,))]
# test Vector
test_coefs = zeros.maximum(ones).coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
assert all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
# test float
test_coefs = zeros.maximum(1.0).coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
assert all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
def test_min(self, framework: FrameworkType) -> None:
"""Test coef.-wise, element-wise minimum wrt to another Vector."""
grad = GradientsTestCase(framework)
ones, zeros = (grad.mock_ones, grad.mock_zeros)
values = [np.zeros((5, 5)), np.zeros((4,)), np.zeros((1,))]
# test Vector
test_coefs = ones.minimum(zeros).coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
assert all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
# test float
test_coefs = zeros.minimum(1.0).coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
assert all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
def test_sign(self, framework: FrameworkType) -> None:
"""Test coefficient-wise sign check"""
grad = GradientsTestCase(framework)
ones = grad.mock_ones
for vec in ones, -1 * ones:
test_coefs = vec.sign().coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
values = [grad.to_numpy(vec.coefs[el]) for el in vec.coefs]
assert all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
def test_eq(self, framework: FrameworkType) -> None:
"""Test __eq__ operator"""
grad = GradientsTestCase(framework)
ones, ones_bis, zeros = grad.mock_ones, grad.mock_ones, grad.mock_zeros
rand = grad.mock_gradient
assert ones == ones_bis
assert zeros != ones
assert ones != rand
assert 1.0 != ones
class TestVector:
"""Test non-abstract methods"""
def test_operator(self, framework: FrameworkType) -> None:
"Test all element-wise operators wiring"
grad = GradientsTestCase(framework)
def _get_sq_root_two(ones, zeros):
"""Returns the comaprison of a hardcoded sequence of operations
with its exptected result"""
values = [
el * (2 ** (1 / 2))
for el in [np.ones((5, 5)), np.ones((4,)), np.ones((1,))]
]
test_grad = (0 + (1.0 * ones + ones * 1.0) * ones / 1.0 - 0) ** (
(zeros + ones - zeros) / (ones + ones)
)
test_coefs = test_grad.coefs
test_values = [grad.to_numpy(test_coefs[el]) for el in test_coefs]
return all(
(values[i] == test_values[i]).all() for i in range(len(values))
)
ones, zeros = grad.mock_ones, grad.mock_zeros
assert _get_sq_root_two(ones, zeros)
assert _get_sq_root_two(ones, 0)
def test_pack(self, framework: FrameworkType) -> None:
"""Test that `Vector.pack` returns JSON-serializable results."""
grad = GradientsTestCase(framework)
ones = grad.mock_ones
packed = ones.pack()
# Check that the output is a dict with str keys.
assert isinstance(packed, dict)
assert all(isinstance(key, str) for key in packed)
# Check that the "packed" dict is JSON-serializable.
dump = json.dumps(packed, default=json_pack)
load = json.loads(dump, object_hook=json_unpack)
assert isinstance(load, dict)
assert load.keys() == packed.keys()
assert all(np.all(load[key] == packed[key]) for key in load)
def test_unpack(self, framework: FrameworkType) -> None:
"""Test that `Vector.unpack` counterparts `Vector.pack` adequately."""
grad = GradientsTestCase(framework)
ones = grad.mock_ones
packed = ones.pack()
test_vec = grad.vector_cls.unpack(packed)
assert test_vec == ones
def test_repr(self, framework: FrameworkType) -> None:
"""Test shape and dtypes together using __repr__"""
grad = GradientsTestCase(framework)
test_value = repr(grad.mock_ones)
value = grad.mock_ones.coefs["0"]
arr_type = f"{type(value).__module__}.{type(value).__name__}"
value = (
f"{grad.vector_cls.__name__} with 3 coefs:"
f"\n 0: float64 {arr_type} with shape (5, 5)"
f"\n 1: float64 {arr_type} with shape (4,)"
f"\n 2: float64 {arr_type} with shape (1,)"
)
assert test_value == value
def test_json_serialization(self, framework: FrameworkType) -> None:
"""Test that a Vector instance is JSON-serializable."""
vector = GradientsTestCase(framework).mock_gradient
dump = json.dumps(vector, default=json_pack)
loaded = json.loads(dump, object_hook=json_unpack)
assert isinstance(loaded, type(vector))
assert loaded == vector
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment