From 743998fb7942975fe285baa16f6a64626506a5ec Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 12 Jul 2024 15:28:16 +0200
Subject: [PATCH] Add unit tests for 'FairbatchDataset'.

---
 .../algorithms/test_fairbatch_dataset.py | 204 ++++++++++++++++++
 1 file changed, 204 insertions(+)
 create mode 100644 test/fairness/algorithms/test_fairbatch_dataset.py

diff --git a/test/fairness/algorithms/test_fairbatch_dataset.py b/test/fairness/algorithms/test_fairbatch_dataset.py
new file mode 100644
index 0000000..a749d27
--- /dev/null
+++ b/test/fairness/algorithms/test_fairbatch_dataset.py
@@ -0,0 +1,204 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for FairBatch dataset wrapper."""
+
+from unittest import mock
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from declearn.fairness.api import FairnessDataset
+from declearn.fairness.core import FairnessInMemoryDataset
+from declearn.fairness.fairbatch import FairbatchDataset
+
+
+COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}
+
+
+class TestFairbatchDataset:
+    """Unit tests for 'declearn.fairness.fairbatch.FairbatchDataset'."""
+
+    def setup_mock_base_dataset(self) -> mock.Mock:
+        """Return a mock FairnessDataset with arbitrary groupwise counts."""
+        base = mock.create_autospec(FairnessDataset, instance=True)
+        base.get_sensitive_group_definitions.return_value = list(COUNTS)
+        base.get_sensitive_group_counts.return_value = COUNTS
+        return base
+
+    def test_wrapped_methods(self) -> None:
+        """Test that API-defined methods are properly wrapped."""
+        # Instantiate a FairbatchDataset wrapping a mock FairnessDataset.
+        base = mock.create_autospec(FairnessDataset, instance=True)
+        data = FairbatchDataset(base)
+        # Test API-defined getters.
+        assert data.get_data_specs() is base.get_data_specs.return_value
+        assert data.get_sensitive_group_definitions() is (
+            base.get_sensitive_group_definitions.return_value
+        )
+        assert data.get_sensitive_group_counts() is (
+            base.get_sensitive_group_counts.return_value.copy()
+        )
+        group = mock.create_autospec(tuple, instance=True)
+        assert data.get_sensitive_group_subset(group) is (
+            base.get_sensitive_group_subset.return_value
+        )
+        base.get_sensitive_group_subset.assert_called_once_with(group)
+        # Test API-defined setter.
+        weights = mock.create_autospec(dict, instance=True)
+        adjust_by_counts = mock.create_autospec(bool, instance=True)
+        data.set_sensitive_group_weights(weights, adjust_by_counts)
+        base.set_sensitive_group_weights.assert_called_once_with(
+            weights, adjust_by_counts
+        )
+
+    def test_get_sampling_probabilities_initial(self) -> None:
+        """Test 'get_sampling_probabilities' upon initialization."""
+        # Instantiate a FairbatchDataset wrapping a mock FairnessDataset.
+        base = self.setup_mock_base_dataset()
+        data = FairbatchDataset(base)
+        # Access initial sampling probabilities and verify their value.
+        probas = data.get_sampling_probabilities()
+        assert isinstance(probas, dict)
+        assert probas.keys() == COUNTS.keys()
+        assert all(isinstance(val, float) for val in probas.values())
+        expected = {key: 1 / len(COUNTS) for key in COUNTS}
+        assert probas == expected
+
+    def test_set_sampling_probabilities_simple(self) -> None:
+        """Test 'set_sampling_probabilities' with matching groups."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that match local groups.
+        probas = {group: idx / 10 for idx, group in enumerate(COUNTS, 1)}
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were assigned.
+        assert data.get_sampling_probabilities() == probas
+
+    def test_set_sampling_probabilities_unnormalized(self) -> None:
+        """Test 'set_sampling_probabilities' with un-normalized values."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that do not sum to 1.
+        probas = {group: float(idx) for idx, group in enumerate(COUNTS, 1)}
+        expect = {key: val / 10 for key, val in probas.items()}
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were corrected, then assigned.
+        assert data.get_sampling_probabilities() == expect
+
+    def test_set_sampling_probabilities_superset(self) -> None:
+        """Test 'set_sampling_probabilities' with unrepresented groups."""
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        # Assign arbitrary probabilities that cover a superset of local groups.
+        probas = {group: idx / 10 for idx, group in enumerate(COUNTS, 1)}
+        expect = probas.copy()
+        probas[(2, 0)] = probas[(2, 1)] = 0.2
+        data.set_sampling_probabilities(group_probas=probas)
+        # Test that inputs were corrected, then assigned.
+        assert data.get_sampling_probabilities() == expect
+
+    def test_set_sampling_probabilities_invalid_values(self) -> None:
+        """Test 'set_sampling_probabilities' with negative values."""
+        probas = {group: float(idx) for idx, group in enumerate(COUNTS, -2)}
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        with pytest.raises(ValueError):
+            data.set_sampling_probabilities(group_probas=probas)
+
+    def test_set_sampling_probabilities_invalid_groups(self) -> None:
+        """Test 'set_sampling_probabilities' with missing groups."""
+        probas = {
+            group: idx / 6 for idx, group in enumerate(list(COUNTS)[1:], 1)
+        }
+        data = FairbatchDataset(base=self.setup_mock_base_dataset())
+        with pytest.raises(ValueError):
+            data.set_sampling_probabilities(group_probas=probas)
+
+    def setup_simple_dataset(self) -> FairbatchDataset:
+        """Set up a simple FairbatchDataset with arbitrary data.
+
+        Samples have a single feature, reflecting the sensitive
+        group to which they belong.
+        """
+        samples = [
+            sample
+            for idx, (group, n_samples) in enumerate(COUNTS.items())
+            for sample in [(group[0], group[1], idx)] * n_samples
+        ]
+        base = FairnessInMemoryDataset(
+            data=pd.DataFrame(samples, columns=["target", "s_attr", "value"]),
+            f_cols=["value"],
+            target="target",
+            s_attr=["s_attr"],
+            sensitive_target=True,
+        )
+        # Wrap it up as a FairbatchDataset.
+        return FairbatchDataset(base)
+
+    def test_generate_batches_simple(self) -> None:
+        """Test 'generate_batches' with a small batch size."""
+        # Set up a simple dataset and assign arbitrary sampling probabilities.
+        data = self.setup_simple_dataset()
+        data.set_sampling_probabilities(
+            {group: idx / 10 for idx, group in enumerate(COUNTS, start=1)}
+        )
+        # Generate batches with a low batch size.
+        # Verify that outputs match expectations.
+        batches = list(data.generate_batches(batch_size=10))
+        assert len(batches) == 10
+        expect_x = np.array(
+            [[idx] for idx in range(len(COUNTS)) for _ in range(idx + 1)]
+        )
+        expect_y = np.array(
+            [lab for idx, (lab, _) in enumerate(COUNTS, 1) for _ in range(idx)]
+        )
+        for batch in batches:
+            assert isinstance(batch, tuple) and (len(batch) == 3)
+            assert isinstance(batch[0], np.ndarray)
+            assert (batch[0] == expect_x).all()
+            assert isinstance(batch[1], np.ndarray)
+            assert (batch[1] == expect_y).all()
+            assert batch[2] is None
+
+    def test_generate_batches_large(self) -> None:
+        """Test 'generate_batches' with a large batch size."""
+        # Set up a simple dataset and assign arbitrary sampling probabilities.
+        data = self.setup_simple_dataset()
+        data.set_sampling_probabilities(
+            {group: idx / 10 for idx, group in enumerate(COUNTS, start=1)}
+        )
+        # Generate batches with a high batch size.
+        # Verify that outputs match expectations.
+        batches = list(data.generate_batches(batch_size=100))
+        assert len(batches) == 1
+        assert isinstance(batches[0][0], np.ndarray)
+        assert isinstance(batches[0][1], np.ndarray)
+        assert batches[0][2] is None
+        expect_x = np.array(
+            [
+                [idx]
+                for idx in range(len(COUNTS))
+                for _ in range(10 * (idx + 1))
+            ]
+        )
+        expect_y = np.array(
+            [
+                lab
+                for idx, (lab, _) in enumerate(COUNTS, 1)
+                for _ in range(idx * 10)
+            ]
+        )
+        assert (batches[0][0] == expect_x).all()
+        assert (batches[0][1] == expect_y).all()
--
GitLab
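
Reviewer note (editorial addendum, not part of the patch): the literal
expectations in 'test_generate_batches_simple' and 'test_generate_batches_large'
encode how FairBatch-style sampling probabilities translate into per-batch
group counts: with probabilities (0.1, 0.2, 0.3, 0.4), a batch of 10 holds
1, 2, 3 and 4 samples from the four groups, and a batch of 100 holds 10, 20,
30 and 40. Below is a minimal standalone sketch of that arithmetic, assuming
proportional allocation as asserted by the tests; the helper
'expected_group_counts' is illustrative and not part of declearn, whose
implementation may round or draw samples differently.

    # Sketch: derive deterministic per-batch group counts from sampling
    # probabilities, mirroring the values asserted in the tests above.
    COUNTS = {(0, 0): 30, (0, 1): 15, (1, 0): 35, (1, 1): 20}

    def expected_group_counts(probas, batch_size):
        """Allocate batch slots to groups proportionally to probabilities."""
        return {
            group: round(prob * batch_size)
            for group, prob in probas.items()
        }

    # Probabilities 0.1 / 0.2 / 0.3 / 0.4, as set in the tests.
    probas = {group: idx / 10 for idx, group in enumerate(COUNTS, start=1)}
    assert expected_group_counts(probas, batch_size=10) == {
        (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4,
    }
    assert expected_group_counts(probas, batch_size=100) == {
        (0, 0): 10, (0, 1): 20, (1, 0): 30, (1, 1): 40,
    }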