diff --git a/declearn/fairness/fairbatch/_dataset.py b/declearn/fairness/fairbatch/_dataset.py
index 2949273c0b60394f132faf672e0bb60813e024b1..e6185a953c8158241b3427c869e16487a5d11e1d 100644
--- a/declearn/fairness/fairbatch/_dataset.py
+++ b/declearn/fairness/fairbatch/_dataset.py
@@ -212,27 +212,36 @@ class FairbatchDataset(FairnessDataset):
             Whether to use poisson sampling rather than batching.
         """
         # backend method; pylint: disable=too-many-arguments
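+        # Gather sampling parameters that are forwarded to '_generate_batches'.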
+        args = (shuffle, replacement, poisson)
         # Fetch the target sub-dataset and its samples count.
         dataset = self.get_sensitive_group_subset(group)
         n_samples = self._counts[group]
-        # Adjust batch size when needed and set up a batches generator.
-        n_repeats, batch_size = divmod(batch_size, n_samples)
-        generator = self._generate_batches(
-            # fmt: off
-            dataset, group, nb_batches, batch_size,
-            shuffle, replacement, poisson,
-        )
+        # When the dataset is large enough, merely yield batches.
+        if batch_size <= n_samples:
+            yield from self._generate_batches(
+                dataset, group, nb_batches, batch_size, *args
+            )
         # When the batch size is larger than the number of data points,
         # make up a base batch with all points (duplicated if needed),
         # that will be combined with further batches of data.
-        if n_repeats:
-            full = self._get_full_dataset(dataset, n_samples, group)
-            full = self._concatenate_batches([full] * n_repeats)
-            for batch in generator:
-                yield self._concatenate_batches((full, batch))
-        # Otherwise, merely yield from the generator.
         else:
-            yield from generator
+            n_repeats, batch_size = divmod(batch_size, n_samples)
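+            # E.g. divmod(25, 10) gives n_repeats=2 and a residual batch_size=5,
+            # i.e. two full copies of the subset plus 5 varying samples per batch.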
+            # Gather the full subset, optionally duplicated.
+            full = self._get_full_dataset(dataset, n_samples, group)
+            if n_repeats > 1:
+                full = self._concatenate_batches([full] * n_repeats)
+            # Append further (batch-varying) samples when needed.
+            if batch_size:
+                for batch in self._generate_batches(
+                    dataset, group, nb_batches, batch_size, *args
+                ):
+                    yield self._concatenate_batches([full, batch])
+            else:  # edge case: batches are exactly N copies of the full dataset
+                for _ in range(nb_batches):
+                    yield full
 
     def _generate_batches(
         self,