Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit a6fc2485 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Use new assertion utils in existing unit tests.

parent 0420531a
No related branches found
No related tags found
1 merge request!15Implement framework-specific OptiModule subclasses.
...@@ -17,16 +17,18 @@ ...@@ -17,16 +17,18 @@
"""Template test-suite for declearn Metric subclasses.""" """Template test-suite for declearn Metric subclasses."""
import json
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Union from typing import Dict, Optional, Union
import numpy as np import numpy as np
import pytest import pytest
from declearn.metrics import Metric from declearn.metrics import Metric
from declearn.utils import json_pack, json_unpack from declearn.test_utils import (
assert_dict_equal,
assert_json_serializable_dict,
)
@dataclass @dataclass
...@@ -60,23 +62,7 @@ class MetricTestCase: ...@@ -60,23 +62,7 @@ class MetricTestCase:
class MetricTestSuite: class MetricTestSuite:
"""Template for declearn Metric subclasses' unit tests suite.""" """Template for declearn Metric subclasses' unit tests suite."""
@staticmethod tol: Optional[float] = 0.0 # optional tolerance to scores' imperfection
def dicts_equal(
dicts_a: Dict[str, Union[float, np.ndarray]],
dicts_b: Dict[str, Union[float, np.ndarray]],
) -> bool:
"""Assert that two dicts dict are equal."""
try:
assert dicts_a.keys() == dicts_b.keys()
for key, v_a in dicts_a.items():
v_b = dicts_b[key]
if isinstance(v_a, np.ndarray):
assert np.all(v_a == v_b)
else:
assert v_a == v_b
except AssertionError:
return False
return True
def test_update(self, test_case: MetricTestCase) -> None: def test_update(self, test_case: MetricTestCase) -> None:
"""Test that the `update` method works as expected.""" """Test that the `update` method works as expected."""
...@@ -85,8 +71,9 @@ class MetricTestSuite: ...@@ -85,8 +71,9 @@ class MetricTestSuite:
metric.update(**test_case.inputs) metric.update(**test_case.inputs)
after = metric.get_states() after = metric.get_states()
assert before.keys() == after.keys() assert before.keys() == after.keys()
assert not self.dicts_equal(before, after) with pytest.raises(AssertionError): # assert not equal
assert self.dicts_equal(after, test_case.states) assert_dict_equal(before, after)
assert_dict_equal(after, test_case.states)
def test_zero_results(self, test_case: MetricTestCase) -> None: def test_zero_results(self, test_case: MetricTestCase) -> None:
"""Test that `get_result` works for un-updated metrics.""" """Test that `get_result` works for un-updated metrics."""
...@@ -101,7 +88,7 @@ class MetricTestSuite: ...@@ -101,7 +88,7 @@ class MetricTestSuite:
metric.update(**test_case.inputs) metric.update(**test_case.inputs)
result = metric.get_result() result = metric.get_result()
scores = test_case.scores scores = test_case.scores
assert self.dicts_equal(result, scores) assert_dict_equal(result, scores, np_tolerance=self.tol)
def test_reset(self, test_case: MetricTestCase) -> None: def test_reset(self, test_case: MetricTestCase) -> None:
"""Test that the `reset` method works as expected.""" """Test that the `reset` method works as expected."""
...@@ -110,7 +97,7 @@ class MetricTestSuite: ...@@ -110,7 +97,7 @@ class MetricTestSuite:
metric.update(**test_case.inputs) metric.update(**test_case.inputs)
metric.reset() metric.reset()
after = metric.get_states() after = metric.get_states()
assert self.dicts_equal(before, after) assert_dict_equal(before, after)
def test_aggreg(self, test_case: MetricTestCase) -> None: def test_aggreg(self, test_case: MetricTestCase) -> None:
"""Test that the `agg_states` method works as expected.""" """Test that the `agg_states` method works as expected."""
...@@ -120,15 +107,16 @@ class MetricTestSuite: ...@@ -120,15 +107,16 @@ class MetricTestSuite:
metric.update(**test_case.inputs) metric.update(**test_case.inputs)
metbis.update(**test_case.inputs) metbis.update(**test_case.inputs)
# Aggregate the second into the first. Verify that they now differ. # Aggregate the second into the first. Verify that they now differ.
assert self.dicts_equal(metric.get_states(), metbis.get_states()) assert_dict_equal(metric.get_states(), metbis.get_states())
metbis.agg_states(metric.get_states()) metbis.agg_states(metric.get_states())
assert self.dicts_equal(metric.get_states(), test_case.states) assert_dict_equal(metric.get_states(), test_case.states)
assert not self.dicts_equal(metric.get_states(), metbis.get_states()) with pytest.raises(AssertionError): # assert not equal
assert_dict_equal(metric.get_states(), metbis.get_states())
# Verify the correctness of the aggregated states and scores. # Verify the correctness of the aggregated states and scores.
states = test_case.agg_states states = test_case.agg_states
scores = test_case.agg_scores scores = test_case.agg_scores
assert self.dicts_equal(metbis.get_states(), states) assert_dict_equal(metbis.get_states(), states)
assert self.dicts_equal(metbis.get_result(), scores) assert_dict_equal(metbis.get_result(), scores, np_tolerance=self.tol)
def test_aggreg_errors(self, test_case: MetricTestCase) -> None: def test_aggreg_errors(self, test_case: MetricTestCase) -> None:
"""Test that the `agg_states` method raises expected exceptions.""" """Test that the `agg_states` method raises expected exceptions."""
...@@ -170,28 +158,21 @@ class MetricTestSuite: ...@@ -170,28 +158,21 @@ class MetricTestSuite:
metric.reset() metric.reset()
metric.update(**inpbis) metric.update(**inpbis)
st_bis = metric.get_states() st_bis = metric.get_states()
assert self.dicts_equal(states, st_bis) assert_dict_equal(states, st_bis)
def test_config(self, test_case: MetricTestCase) -> None: def test_config(self, test_case: MetricTestCase) -> None:
"""Test that the metric supports (de)serialization from a dict.""" """Test that the metric supports (de)serialization from a dict."""
metric = test_case.metric metric = test_case.metric
# Test that `get_config` returns a JSON-serializable dict. # Test that `get_config` returns a JSON-serializable dict.
config = metric.get_config() config = metric.get_config()
assert isinstance(config, dict) assert_json_serializable_dict(config)
json_c = json.dumps(config, default=json_pack)
cfgbis = json.loads(json_c, object_hook=json_unpack)
for key, val in config.items():
# Adjust for tuple-list JSON conversion.
if isinstance(val, tuple):
cfgbis[key] = tuple(cfgbis[key])
assert cfgbis == config
# Test that `from_config` produces a similar Metric. # Test that `from_config` produces a similar Metric.
metbis = type(metric).from_config(config) metbis = type(metric).from_config(config)
assert isinstance(metbis, type(metric)) assert isinstance(metbis, type(metric))
cfgbis = metbis.get_config() cfgbis = metbis.get_config()
assert cfgbis == config assert_dict_equal(cfgbis, config)
# Test that `from_specs` works properly as well. # Test that `from_specs` works properly as well.
metter = Metric.from_specs(metric.name, config) metter = Metric.from_specs(metric.name, config)
assert isinstance(metter, type(metric)) assert isinstance(metter, type(metric))
cfgter = metter.get_config() cfgter = metter.get_config()
assert cfgter == config assert_dict_equal(cfgter, config)
...@@ -72,21 +72,7 @@ def test_case_fixture( ...@@ -72,21 +72,7 @@ def test_case_fixture(
class TestRSquared(MeanMetricTestSuite): class TestRSquared(MeanMetricTestSuite):
"""Unit tests for `RSquared` Metric.""" """Unit tests for `RSquared` Metric."""
@staticmethod tol = 1e-12 # allow declearn and sklearn scores to differ at 10^-12 prec.
def dicts_equal(
dicts_a: Dict[str, Union[float, np.ndarray]],
dicts_b: Dict[str, Union[float, np.ndarray]],
) -> bool:
# Override the base behaviour: allow for very small (10^-12)
# numerical imprecisions in values' equality assertions.
try:
assert dicts_a.keys() == dicts_b.keys()
for key, v_a in dicts_a.items():
v_b = dicts_b[key]
assert np.allclose(v_a, v_b, rtol=0, atol=1e-12)
except AssertionError:
return False
return True
def test_zero_result(self, test_case: MetricTestCase) -> None: def test_zero_result(self, test_case: MetricTestCase) -> None:
"""Test that `get_results` works with zero-valued divisor.""" """Test that `get_results` works with zero-valued divisor."""
......
...@@ -42,6 +42,7 @@ from declearn.optimizer.modules import NoiseModule, OptiModule ...@@ -42,6 +42,7 @@ from declearn.optimizer.modules import NoiseModule, OptiModule
from declearn.test_utils import ( from declearn.test_utils import (
FrameworkType, FrameworkType,
GradientsTestCase, GradientsTestCase,
assert_dict_equal,
assert_json_serializable_dict, assert_json_serializable_dict,
) )
from declearn.utils import access_types_mapping, set_device_policy from declearn.utils import access_types_mapping, set_device_policy
...@@ -104,7 +105,7 @@ class TestOptiModule(PluginTestBase): ...@@ -104,7 +105,7 @@ class TestOptiModule(PluginTestBase):
test_case = GradientsTestCase(framework) test_case = GradientsTestCase(framework)
module.run(test_case.mock_gradient) module.run(test_case.mock_gradient)
module.set_state(initial) module.set_state(initial)
assert module.get_state() == initial assert_dict_equal(module.get_state(), initial)
def test_set_state_updated( def test_set_state_updated(
self, cls: Type[OptiModule], framework: FrameworkType self, cls: Type[OptiModule], framework: FrameworkType
...@@ -117,7 +118,7 @@ class TestOptiModule(PluginTestBase): ...@@ -117,7 +118,7 @@ class TestOptiModule(PluginTestBase):
states = module.get_state() states = module.get_state()
module = cls() module = cls()
module.set_state(states) module.set_state(states)
assert module.get_state() == states assert_dict_equal(module.get_state(), states)
def test_set_state_results( def test_set_state_results(
self, cls: Type[OptiModule], framework: FrameworkType self, cls: Type[OptiModule], framework: FrameworkType
......
...@@ -139,11 +139,6 @@ class TestOptimizer: ...@@ -139,11 +139,6 @@ class TestOptimizer:
modules=[MockOptiModule(arg="optimodule")], modules=[MockOptiModule(arg="optimodule")],
) )
config = optimizer.get_config() config = optimizer.get_config()
assert isinstance(config, dict)
# Hack around the config dict to account for JSON converting tuples.
config["regularizers"] = [list(e) for e in config["regularizers"]]
config["modules"] = [list(e) for e in config["modules"]]
# Run the JSON-serializability test.
assert_json_serializable_dict(config) assert_json_serializable_dict(config)
def test_from_config(self) -> None: def test_from_config(self) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment