From 2b020369a21697513204a312a1d4dc7a7c53220d Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 27 Apr 2023 10:42:40 +0200
Subject: [PATCH] Extend unit tests for type-registration utils.

---
 test/utils/test_register.py | 36 ++++++++++++++++++++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/test/utils/test_register.py b/test/utils/test_register.py
index 2cf5085c..5d47b4f4 100644
--- a/test/utils/test_register.py
+++ b/test/utils/test_register.py
@@ -33,9 +33,13 @@ from declearn.utils import (
 def test_create_types_registry() -> None:
     """Unit tests for 'create_types_registry'."""
     group = f"test_{time.time_ns()}"
-    assert create_types_registry(object, group) is object
+
+    class AnyClass:  # pylint: disable=all
+        pass
+
+    assert create_types_registry(AnyClass, group) is AnyClass
     with pytest.raises(KeyError):
-        create_types_registry(object, group)
+        create_types_registry(AnyClass, group)
 
 
 def test_register_type() -> None:
@@ -75,6 +79,9 @@ def test_register_type_fails() -> None:
     group = f"test_{time.time_ns()}"
     with pytest.raises(KeyError):
         register_type(BaseClass, name="base", group=group)
+    # Try registering in any group, with no valid parent group.
+    with pytest.raises(TypeError):
+        register_type(BaseClass, name="base", group=None)
     # Try registering in a group with wrong class constraints.
     create_types_registry(BaseClass, group)
     with pytest.raises(TypeError):
@@ -103,10 +110,31 @@ def test_access_registered() -> None:
     name_2 = f"test_{time.time_ns()}"
     with pytest.raises(KeyError):
         access_registered(name_2, group=name)  # invalid name under group
+    with pytest.raises(KeyError):
+        access_registered(name_2, group=None)  # invalid name under any group
     with pytest.raises(KeyError):
         access_registered(name, group=name_2)  # non-existing group
 
 
+def test_register_unspecified_group() -> None:
+    """Unit tests for type-registration with implicit group membership."""
+    group = f"test_{time.time_ns()}"
+
+    # Define a parent class and an associted type registry.
+    @create_types_registry(name=group)
+    class Parent:  # pylint: disable=all
+        pass
+
+    # Define a child class, and register it without specifying the group.
+    @register_type(name="new-child")
+    class Child(Parent):  # pylint: disable=all
+        pass
+
+    # Verify that the class was put into the proper group.
+    assert access_registered("new-child") is Child
+    assert access_registration_info(Child) == ("new-child", group)
+
+
 def test_access_registeration_info() -> None:
     """Unit tests for 'access_registration_info'."""
 
@@ -157,3 +185,7 @@ def test_access_types_mapping() -> None:
     assert mapping != access_types_mapping(group=group)
     with pytest.raises(KeyError):
         access_registered("renamed", group=group)
+
+    # Test that the expected exception is raised for non-existing groups.
+    with pytest.raises(KeyError):
+        access_types_mapping(group=f"test_{time.time_ns()}")
-- 
GitLab