diff --git a/test/data_info/test_classes_field.py b/test/data_info/test_classes_field.py new file mode 100644 index 0000000000000000000000000000000000000000..98addb7cf01db28b79cfffbc1d366751d0e21e42 --- /dev/null +++ b/test/data_info/test_classes_field.py @@ -0,0 +1,59 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.data_info.ClassesField'.""" + + +import numpy as np +import pytest + +from declearn.data_info import ClassesField + + +class TestClassesField: + """Unit tests for 'declearn.data_info.ClassesField'.""" + + def test_is_valid_list(self) -> None: + """Test `is_valid` with a valid list value.""" + assert ClassesField.is_valid([0, 1]) + + def test_is_valid_set(self) -> None: + """Test `is_valid` with a valid set value.""" + assert ClassesField.is_valid({0, 1}) + + def test_is_valid_tuple(self) -> None: + """Test `is_valid` with a valid tuple value.""" + assert ClassesField.is_valid((0, 1)) + + def test_is_valid_array(self) -> None: + """Test `is_valid` with a valid numpy array value.""" + assert ClassesField.is_valid(np.array([0, 1])) + + def test_is_invalid_2d_array(self) -> None: + """Test `is_valid` with an invalid numpy array value.""" + assert not ClassesField.is_valid(np.array([[0, 1], [2, 3]])) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = ([0, 1], (0, 1), {1, 2}, np.array([1, 3])) + assert ClassesField.combine(*values) == {0, 1, 2, 3} + + def test_combine_fails(self) -> None: + """Test `combine` with some invalid inputs.""" + values = ([0, 1], np.array([[0, 1], [2, 3]])) + with pytest.raises(ValueError): + ClassesField.combine(*values) diff --git a/test/data_info/test_datatype_field.py b/test/data_info/test_datatype_field.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c2f1bf76c11fcea1ff5d628374e7dddeb6115a --- /dev/null +++ b/test/data_info/test_datatype_field.py @@ -0,0 +1,57 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.data_info.DataTypeField'.""" + + +import numpy as np +import pytest + +from declearn.data_info import DataTypeField + + +class TestDataTypeField: + """Unit tests for 'declearn.data_info.DataTypeField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid values.""" + assert DataTypeField.is_valid("float32") + assert DataTypeField.is_valid("float64") + assert DataTypeField.is_valid("int32") + assert DataTypeField.is_valid("uint8") + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not DataTypeField.is_valid(np.int32) + assert not DataTypeField.is_valid("mocktype") + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = ["float32", "float32"] + assert DataTypeField.combine(*values) == "float32" + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = ["float32", "mocktype"] + with pytest.raises(ValueError): + DataTypeField.combine(*values) + + def test_combine_incompatible(self) -> None: + """Test `combine` with incompatible inputs.""" + values = ["float32", "float16"] + with pytest.raises(ValueError): + DataTypeField.combine(*values) diff --git a/test/data_info/test_nbsamples_field.py b/test/data_info/test_nbsamples_field.py new file mode 100644 index 0000000000000000000000000000000000000000..fc53f3dfdc31dd29ef21b9d0a2bb524034f5e739 --- /dev/null +++ b/test/data_info/test_nbsamples_field.py @@ -0,0 +1,52 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.data_info.NbSamplesField'.""" + + +import pytest + +from declearn.data_info import NbSamplesField + + +class TestNbSamplesField: + """Unit tests for 'declearn.data_info.NbSamplesField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid input values.""" + assert NbSamplesField.is_valid(32) + assert NbSamplesField.is_valid(100) + assert NbSamplesField.is_valid(8192) + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not NbSamplesField.is_valid(16.5) + assert not NbSamplesField.is_valid(-12) + assert not NbSamplesField.is_valid(None) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + values = [32, 128] + assert NbSamplesField.combine(*values) == 160 + values = [64, 64, 64, 64] + assert NbSamplesField.combine(*values) == 256 + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = [128, -12] + with pytest.raises(ValueError): + NbSamplesField.combine(*values) diff --git a/test/data_info/test_shape_field.py b/test/data_info/test_shape_field.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9f61c4a06c28cb5d31a31fad560109a6d95089 --- /dev/null +++ b/test/data_info/test_shape_field.py @@ -0,0 +1,67 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for 'declearn.data_info.FeaturesShapeField'.""" + + +import pytest + +from declearn.data_info import FeaturesShapeField + + +class TestFeaturesShapeField: + """Unit tests for 'declearn.data_info.FeaturesShapeField'.""" + + def test_is_valid(self) -> None: + """Test `is_valid` with some valid input values.""" + # 1-d ; fixed 3-d (image-like) ; variable 2-d (text-like). + assert FeaturesShapeField.is_valid([32]) + assert FeaturesShapeField.is_valid([64, 64, 3]) + assert FeaturesShapeField.is_valid([None, 128]) + # Same inputs, as tuples. + assert FeaturesShapeField.is_valid((32,)) + assert FeaturesShapeField.is_valid((64, 64, 3)) + assert FeaturesShapeField.is_valid((None, 128)) + + def test_is_not_valid(self) -> None: + """Test `is_valid` with invalid values.""" + assert not FeaturesShapeField.is_valid(32) + assert not FeaturesShapeField.is_valid([32, -1]) + + def test_combine(self) -> None: + """Test `combine` with valid and compatible inputs.""" + # 1-d inputs. + values = [[32], (32,)] + assert FeaturesShapeField.combine(*values) == (32,) + # 3-d fixed-size inputs. + values = [[16, 16, 3], (16, 16, 3)] + assert FeaturesShapeField.combine(*values) == (16, 16, 3) + # 2-d variable-size inputs. + values = [[None, 512], (None, 512)] # type: ignore + assert FeaturesShapeField.combine(*values) == (None, 512) + + def test_combine_invalid(self) -> None: + """Test `combine` with invalid inputs.""" + values = [[32], [32, -1]] + with pytest.raises(ValueError): + FeaturesShapeField.combine(*values) + + def test_combine_incompatible(self) -> None: + """Test `combine` with incompatible inputs.""" + values = [(None, 32), (128,)] + with pytest.raises(ValueError): + FeaturesShapeField.combine(*values)