From 2b020369a21697513204a312a1d4dc7a7c53220d Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 27 Apr 2023 10:42:40 +0200 Subject: [PATCH] Extend unit tests for type-registration utils. --- test/utils/test_register.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/test/utils/test_register.py b/test/utils/test_register.py index 2cf5085c..5d47b4f4 100644 --- a/test/utils/test_register.py +++ b/test/utils/test_register.py @@ -33,9 +33,13 @@ from declearn.utils import ( def test_create_types_registry() -> None: """Unit tests for 'create_types_registry'.""" group = f"test_{time.time_ns()}" - assert create_types_registry(object, group) is object + + class AnyClass: # pylint: disable=all + pass + + assert create_types_registry(AnyClass, group) is AnyClass with pytest.raises(KeyError): - create_types_registry(object, group) + create_types_registry(AnyClass, group) def test_register_type() -> None: @@ -75,6 +79,9 @@ def test_register_type_fails() -> None: group = f"test_{time.time_ns()}" with pytest.raises(KeyError): register_type(BaseClass, name="base", group=group) + # Try registering in any group, with no valid parent group. + with pytest.raises(TypeError): + register_type(BaseClass, name="base", group=None) # Try registering in a group with wrong class constraints. create_types_registry(BaseClass, group) with pytest.raises(TypeError): @@ -103,10 +110,31 @@ def test_access_registered() -> None: name_2 = f"test_{time.time_ns()}" with pytest.raises(KeyError): access_registered(name_2, group=name) # invalid name under group + with pytest.raises(KeyError): + access_registered(name_2, group=None) # invalid name under any group with pytest.raises(KeyError): access_registered(name, group=name_2) # non-existing group +def test_register_unspecified_group() -> None: + """Unit tests for type-registration with implicit group membership.""" + group = f"test_{time.time_ns()}" + + # Define a parent class and an associted type registry. + @create_types_registry(name=group) + class Parent: # pylint: disable=all + pass + + # Define a child class, and register it without specifying the group. + @register_type(name="new-child") + class Child(Parent): # pylint: disable=all + pass + + # Verify that the class was put into the proper group. + assert access_registered("new-child") is Child + assert access_registration_info(Child) == ("new-child", group) + + def test_access_registeration_info() -> None: """Unit tests for 'access_registration_info'.""" @@ -157,3 +185,7 @@ def test_access_types_mapping() -> None: assert mapping != access_types_mapping(group=group) with pytest.raises(KeyError): access_registered("renamed", group=group) + + # Test that the expected exception is raised for non-existing groups. + with pytest.raises(KeyError): + access_types_mapping(group=f"test_{time.time_ns()}") -- GitLab