diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py index 2949273c0b60394f132faf672e0bb60813e024b1..e6185a953c8158241b3427c869e16487a5d11e1d 100644 --- a/declearn/fairness/fairbatch/_dataset.py +++ b/declearn/fairness/fairbatch/_dataset.py @@ -212,27 +212,33 @@ class FairbatchDataset(FairnessDataset): Whether to use poisson sampling rather than batching. """ # backend method; pylint: disable=too-many-arguments + args = (shuffle, replacement, poisson) # Fetch the target sub-dataset and its samples count. dataset = self.get_sensitive_group_subset(group) n_samples = self._counts[group] - # Adjust batch size when needed and set up a batches generator. - n_repeats, batch_size = divmod(batch_size, n_samples) - generator = self._generate_batches( - # fmt: off - dataset, group, nb_batches, batch_size, - shuffle, replacement, poisson, - ) + # When the dataset is large enough, merely yield batches. + if batch_size <= n_samples: + yield from self._generate_batches( + dataset, group, nb_batches, batch_size, *args + ) # When the batch size is larger than the number of data points, # make up a base batch will all points (duplicated if needed), # that will be combined with further batches of data. - if n_repeats: - full = self._get_full_dataset(dataset, n_samples, group) - full = self._concatenate_batches([full] * n_repeats) - for batch in generator: - yield self._concatenate_batches((full, batch)) - # Otherwise, merely yield from the generator. else: - yield from generator + n_repeats, batch_size = divmod(batch_size, n_samples) + # Gather the full subset, optionally duplicated. + full = self._get_full_dataset(dataset, n_samples, group) + if n_repeats > 1: + full = self._concatenate_batches([full] * n_repeats) + # Add up further (batch-varying) samples (when needed). + if batch_size: + for batch in self._generate_batches( + dataset, group, nb_batches, batch_size, *args + ): + yield self._concatenate_batches([full, batch]) + else: # edge case: require exactly N times the full dataset + for _ in range(nb_batches): + yield full def _generate_batches( self,