From b3b75a7aa970c15a383ccdf306676e23f492db16 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 3 Feb 2023 17:38:39 +0100 Subject: [PATCH] Update black to version 23. --- declearn/main/utils/_training.py | 1 + declearn/model/api/_vector.py | 1 + declearn/utils/_dataclass.py | 1 + pyproject.toml | 2 +- test/communication/test_routines.py | 2 ++ test/metrics/test_metricset.py | 6 +++--- test/utils/test_json.py | 2 ++ test/utils/test_register.py | 6 ++++++ test/utils/test_serialize.py | 1 + 9 files changed, 18 insertions(+), 4 deletions(-) diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index be142608..ff8b816f 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -92,6 +92,7 @@ class TrainingManager: ) -> Metric: """Return an ad-hoc Metric object to compute the model's loss.""" loss_fn = self.model.loss_function + # Write a custom, unregistered Metric subclass. class LossMetric(MeanMetric, register=False): """Ad hoc Metric wrapping a model's loss function.""" diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py index 46ac0a3b..f660c5f6 100644 --- a/declearn/model/api/_vector.py +++ b/declearn/model/api/_vector.py @@ -347,6 +347,7 @@ def register_vector_type( as a class decorator. """ v_types = (v_type, *types) + # Set up a registration function. def register(cls: Type[Vector]) -> Type[Vector]: nonlocal name, v_types diff --git a/declearn/utils/_dataclass.py b/declearn/utils/_dataclass.py index 1a481601..bf1c01b9 100644 --- a/declearn/utils/_dataclass.py +++ b/declearn/utils/_dataclass.py @@ -151,6 +151,7 @@ def dataclass_from_init( args_field = param.name if param.kind is param.VAR_KEYWORD: kwargs_field = param.name + # Add a method to instantiate from the dataclass. def instantiate(self) -> cls: # type: ignore """Instantiate from the wrapped init parameters.""" diff --git a/pyproject.toml b/pyproject.toml index 4296a8e2..ff128d84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ websockets = [ ] tests = [ # test-specific dependencies - "black ~= 22.0", + "black ~= 23.0", "mypy >= 0.930", "pylint >= 2.14", "pytest >= 6.1", diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py index 3f04cd1d..ab66f86f 100644 --- a/test/communication/test_routines.py +++ b/test/communication/test_routines.py @@ -131,6 +131,7 @@ def _build_server_func( "certificate": ssl_cert["server_cert"] if use_ssl else None, "private_key": ssl_cert["server_pkey"] if use_ssl else None, } # type: Dict[str, Any] + # Define a coroutine that spawns and runs a server. async def server_coroutine() -> None: """Spawn a client and run `server_routine` in its context.""" @@ -158,6 +159,7 @@ def _build_client_funcs( server_uri = "localhost:8765" if protocol == "websockets": server_uri = f"ws{'s' * use_ssl}://{server_uri}" + # Define a coroutine that spawns and runs a client. async def client_coroutine( name: str, diff --git a/test/metrics/test_metricset.py b/test/metrics/test_metricset.py index 62cedf71..15637686 100644 --- a/test/metrics/test_metricset.py +++ b/test/metrics/test_metricset.py @@ -11,9 +11,9 @@ import pytest from declearn.metrics import MeanAbsoluteError, MeanSquaredError, MetricSet -def get_mock_metricset() -> Tuple[ - MeanAbsoluteError, MeanSquaredError, MetricSet -]: +def get_mock_metricset() -> ( + Tuple[MeanAbsoluteError, MeanSquaredError, MetricSet] +): """Provide with a MetricSet wrapping mock metrics.""" mae = mock.create_autospec(MeanAbsoluteError, instance=True) mae.name = MeanAbsoluteError.name diff --git a/test/utils/test_json.py b/test/utils/test_json.py index b17e2d9c..1e29351a 100644 --- a/test/utils/test_json.py +++ b/test/utils/test_json.py @@ -46,6 +46,7 @@ def test_add_json_support() -> None: not that the associated mechanics perform well. These are tested in `test_json_pack` and `test_json_unpack_known`. """ + # Declare a second, empty custom type for this test only. class OtherType: # pylint: disable=all pass @@ -66,6 +67,7 @@ def test_add_json_support() -> None: def test_json_pack() -> None: """Unit tests for `json_pack` with custom-specified objects.""" + # Define a subtype of CustomType (to ensure it is not supported). class SubType(CustomType): # pylint: disable=all pass diff --git a/test/utils/test_register.py b/test/utils/test_register.py index 302a04eb..575f760a 100644 --- a/test/utils/test_register.py +++ b/test/utils/test_register.py @@ -25,6 +25,7 @@ def test_create_types_registry() -> None: def test_register_type() -> None: """Unit tests for 'register_type' using valid instructions.""" + # Define mock custom classes. class BaseClass: # pylint: disable=all pass @@ -38,6 +39,7 @@ def test_register_type() -> None: assert register_type(BaseClass, name="base", group=group) is BaseClass # Register ChildClass. assert register_type(ChildClass, name="child", group=group) is ChildClass + # Register another BaseClass-inheriting class using decorator syntax. @register_type(name="other", group=group) class OtherChild(BaseClass): @@ -46,6 +48,7 @@ def test_register_type() -> None: def test_register_type_fails() -> None: """Unit tests for 'register_type' using invalid instructions.""" + # Define mock custom classes. class BaseClass: # pylint: disable=all pass @@ -69,6 +72,7 @@ def test_register_type_fails() -> None: def test_access_registered() -> None: """Unit tests for 'access_registered'.""" + # Define a mock custom class. class Class: # pylint: disable=all pass @@ -90,6 +94,7 @@ def test_access_registered() -> None: def test_access_registeration_info() -> None: """Unit tests for 'access_registration_info'.""" + # Define a pair of mock custom class. class Class_1: # pylint: disable=all pass @@ -116,6 +121,7 @@ def test_access_registeration_info() -> None: def test_access_types_mapping() -> None: """Unit tests for 'access_types_mapping'.""" group = f"test_{time.time_ns()}" + # Define mock custom type-registered classes. @register_type(name="base", group=group) @create_types_registry(name=group) diff --git a/test/utils/test_serialize.py b/test/utils/test_serialize.py index 468ecca9..62a8e9f5 100644 --- a/test/utils/test_serialize.py +++ b/test/utils/test_serialize.py @@ -44,6 +44,7 @@ class MockClass: @pytest.fixture(name="registered_class") def fixture_registered_class() -> Tuple[Type[MockClass], str]: """Provide with a type-registered MockClass subclass.""" + # Declare a subtype to avoid side effects between tests. class SubClass(MockClass): # pylint: disable=all pass -- GitLab