diff --git a/test/metrics/metric_testing.py b/test/metrics/metric_testing.py
index 4141465093b268eb348872971f4cda30166a22d1..11c48ace1bdfc069cebe7f05e82326dc38907deb 100644
--- a/test/metrics/metric_testing.py
+++ b/test/metrics/metric_testing.py
@@ -17,16 +17,18 @@
 
 """Template test-suite for declearn Metric subclasses."""
 
-import json
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 import numpy as np
 import pytest
 
 from declearn.metrics import Metric
-from declearn.utils import json_pack, json_unpack
+from declearn.test_utils import (
+    assert_dict_equal,
+    assert_json_serializable_dict,
+)
 
 
 @dataclass
@@ -60,23 +62,7 @@ class MetricTestCase:
 class MetricTestSuite:
     """Template for declearn Metric subclasses' unit tests suite."""
 
-    @staticmethod
-    def dicts_equal(
-        dicts_a: Dict[str, Union[float, np.ndarray]],
-        dicts_b: Dict[str, Union[float, np.ndarray]],
-    ) -> bool:
-        """Assert that two dicts dict are equal."""
-        try:
-            assert dicts_a.keys() == dicts_b.keys()
-            for key, v_a in dicts_a.items():
-                v_b = dicts_b[key]
-                if isinstance(v_a, np.ndarray):
-                    assert np.all(v_a == v_b)
-                else:
-                    assert v_a == v_b
-        except AssertionError:
-            return False
-        return True
+    tol: Optional[float] = 0.0  # optional tolerance to scores' imperfection
 
     def test_update(self, test_case: MetricTestCase) -> None:
         """Test that the `update` method works as expected."""
@@ -85,8 +71,9 @@ class MetricTestSuite:
         metric.update(**test_case.inputs)
         after = metric.get_states()
         assert before.keys() == after.keys()
-        assert not self.dicts_equal(before, after)
-        assert self.dicts_equal(after, test_case.states)
+        with pytest.raises(AssertionError):  # assert not equal
+            assert_dict_equal(before, after)
+        assert_dict_equal(after, test_case.states)
 
     def test_zero_results(self, test_case: MetricTestCase) -> None:
         """Test that `get_result` works for un-updated metrics."""
@@ -101,7 +88,7 @@
         metric.update(**test_case.inputs)
         result = metric.get_result()
         scores = test_case.scores
-        assert self.dicts_equal(result, scores)
+        assert_dict_equal(result, scores, np_tolerance=self.tol)
 
     def test_reset(self, test_case: MetricTestCase) -> None:
         """Test that the `reset` method works as expected."""
@@ -110,7 +97,7 @@
         metric.update(**test_case.inputs)
         metric.reset()
         after = metric.get_states()
-        assert self.dicts_equal(before, after)
+        assert_dict_equal(before, after)
 
     def test_aggreg(self, test_case: MetricTestCase) -> None:
         """Test that the `agg_states` method works as expected."""
@@ -120,15 +107,16 @@
         metric.update(**test_case.inputs)
         metbis.update(**test_case.inputs)
         # Aggregate the second into the first. Verify that they now differ.
-        assert self.dicts_equal(metric.get_states(), metbis.get_states())
+        assert_dict_equal(metric.get_states(), metbis.get_states())
         metbis.agg_states(metric.get_states())
-        assert self.dicts_equal(metric.get_states(), test_case.states)
-        assert not self.dicts_equal(metric.get_states(), metbis.get_states())
+        assert_dict_equal(metric.get_states(), test_case.states)
+        with pytest.raises(AssertionError):  # assert not equal
+            assert_dict_equal(metric.get_states(), metbis.get_states())
         # Verify the correctness of the aggregated states and scores.
         states = test_case.agg_states
         scores = test_case.agg_scores
-        assert self.dicts_equal(metbis.get_states(), states)
-        assert self.dicts_equal(metbis.get_result(), scores)
+        assert_dict_equal(metbis.get_states(), states)
+        assert_dict_equal(metbis.get_result(), scores, np_tolerance=self.tol)
 
     def test_aggreg_errors(self, test_case: MetricTestCase) -> None:
         """Test that the `agg_states` method raises expected exceptions."""
@@ -170,28 +158,21 @@ class MetricTestSuite:
         metric.reset()
         metric.update(**inpbis)
         st_bis = metric.get_states()
-        assert self.dicts_equal(states, st_bis)
+        assert_dict_equal(states, st_bis)
 
     def test_config(self, test_case: MetricTestCase) -> None:
         """Test that the metric supports (de)serialization from a dict."""
         metric = test_case.metric
         # Test that `get_config` returns a JSON-serializable dict.
         config = metric.get_config()
-        assert isinstance(config, dict)
-        json_c = json.dumps(config, default=json_pack)
-        cfgbis = json.loads(json_c, object_hook=json_unpack)
-        for key, val in config.items():
-            # Adjust for tuple-list JSON conversion.
-            if isinstance(val, tuple):
-                cfgbis[key] = tuple(cfgbis[key])
-        assert cfgbis == config
+        assert_json_serializable_dict(config)
         # Test that `from_config` produces a similar Metric.
         metbis = type(metric).from_config(config)
         assert isinstance(metbis, type(metric))
         cfgbis = metbis.get_config()
-        assert cfgbis == config
+        assert_dict_equal(cfgbis, config)
         # Test that `from_specs` works properly as well.
         metter = Metric.from_specs(metric.name, config)
         assert isinstance(metter, type(metric))
         cfgter = metter.get_config()
-        assert cfgter == config
+        assert_dict_equal(cfgter, config)
diff --git a/test/metrics/test_rsquared.py b/test/metrics/test_rsquared.py
index 56550b7fa45372057e7c5d8990fa502ebfb1c2b3..ccf080a4310fadd1556b52f860b8800d3c521b62 100644
--- a/test/metrics/test_rsquared.py
+++ b/test/metrics/test_rsquared.py
@@ -72,21 +72,7 @@ def test_case_fixture(
 class TestRSquared(MeanMetricTestSuite):
     """Unit tests for `RSquared` Metric."""
 
-    @staticmethod
-    def dicts_equal(
-        dicts_a: Dict[str, Union[float, np.ndarray]],
-        dicts_b: Dict[str, Union[float, np.ndarray]],
-    ) -> bool:
-        # Override the base behaviour: allow for very small (10^-12)
-        # numerical imprecisions in values' equality assertions.
-        try:
-            assert dicts_a.keys() == dicts_b.keys()
-            for key, v_a in dicts_a.items():
-                v_b = dicts_b[key]
-                assert np.allclose(v_a, v_b, rtol=0, atol=1e-12)
-        except AssertionError:
-            return False
-        return True
+    tol = 1e-12  # allow declearn and sklearn scores to differ at 10^-12 prec.
 
     def test_zero_result(self, test_case: MetricTestCase) -> None:
        """Test that `get_results` works with zero-valued divisor."""
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index d6a1db0e28f89cfe78fe4bf792b12cabb24c1d7a..b79711570690755a156fd5384cebbf9bbc9638d1 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -42,6 +42,7 @@ from declearn.optimizer.modules import NoiseModule, OptiModule
 from declearn.test_utils import (
     FrameworkType,
     GradientsTestCase,
+    assert_dict_equal,
     assert_json_serializable_dict,
 )
 from declearn.utils import access_types_mapping, set_device_policy
@@ -104,7 +105,7 @@ class TestOptiModule(PluginTestBase):
         test_case = GradientsTestCase(framework)
         module.run(test_case.mock_gradient)
         module.set_state(initial)
-        assert module.get_state() == initial
+        assert_dict_equal(module.get_state(), initial)
 
     def test_set_state_updated(
         self, cls: Type[OptiModule], framework: FrameworkType
@@ -117,7 +118,7 @@ class TestOptiModule(PluginTestBase):
         states = module.get_state()
         module = cls()
         module.set_state(states)
-        assert module.get_state() == states
+        assert_dict_equal(module.get_state(), states)
 
     def test_set_state_results(
         self, cls: Type[OptiModule], framework: FrameworkType
diff --git a/test/optimizer/test_optimizer.py b/test/optimizer/test_optimizer.py
index babdf50057b30a5d930bb7cde2824166b14f3143..30eb2c918848aeb7f0051f77017f403f489147e1 100644
--- a/test/optimizer/test_optimizer.py
+++ b/test/optimizer/test_optimizer.py
@@ -139,11 +139,6 @@ class TestOptimizer:
             modules=[MockOptiModule(arg="optimodule")],
         )
         config = optimizer.get_config()
-        assert isinstance(config, dict)
-        # Hack around the config dict to account for JSON converting tuples.
-        config["regularizers"] = [list(e) for e in config["regularizers"]]
-        config["modules"] = [list(e) for e in config["modules"]]
-        # Run the JSON-serializability test.
         assert_json_serializable_dict(config)
 
     def test_from_config(self) -> None:
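
Usage note: the helpers adopted throughout this patch come from `declearn.test_utils`, as imported in the first hunk. Below is a minimal sketch of the resulting assertion pattern, assuming that `assert_dict_equal` compares float and numpy-array values alike and that its `np_tolerance` keyword (shown in the diff) bounds the absolute numerical deviation allowed between array values; the state dicts are made up for illustration.

import numpy as np

from declearn.test_utils import (
    assert_dict_equal,
    assert_json_serializable_dict,
)

# Hypothetical metric states that differ by a tiny numerical imprecision.
states_a = {"sum_of_squares": np.array([1.0, 2.0]), "n_samples": 4.0}
states_b = {"sum_of_squares": np.array([1.0, 2.0 + 1e-13]), "n_samples": 4.0}

# With the MetricTestSuite default (tol = 0.0) the comparison is exact and
# would fail here; TestRSquared sets tol = 1e-12 so that declearn and
# sklearn scores may differ by floating-point noise.
assert_dict_equal(states_a, states_b, np_tolerance=1e-12)

# Single-call JSON-serializability check, replacing both the manual
# json.dumps/json.loads round-trip removed from `test_config` and the
# tuple-vs-list hack removed from test_optimizer.py.
assert_json_serializable_dict({"name": "mock", "threshold": 0.5})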