Verified Commit 5582ac21 authored by ANDREY Paul

Fix a 'FairbatchDataset.generate_batches' edge case.

Fix the case when the batch size for a given subset is
exactly a multiple of the number of samples for that
subset, which would result in an exception (at least
with in-memory data).
parent 427e617e
1 merge request: !69 Enable Fairness-Aware Federated Learning
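
The arithmetic behind the bug is easy to show in isolation. The following is a minimal standalone sketch (plain Python with hypothetical variable names; the real logic lives in FairbatchDataset.generate_batches and its backend helpers):

# When the requested batch size exceeds the group's sample count, the
# method splits it into full copies of the subset plus a remainder that
# is drawn as a regular batch.
n_samples = 8      # samples available for some sensitive group
batch_size = 16    # exactly 2 * n_samples

n_repeats, rest = divmod(batch_size, n_samples)
assert (n_repeats, rest) == (2, 0)

# Pre-fix, the remainder batch was requested unconditionally, so the
# backend batch generator was asked for batches of size 0, resulting in
# an exception (at least with in-memory data). Post-fix, the zero-sized
# remainder branch is skipped and the (duplicated) full subset is
# yielded for each requested batch.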
@@ -212,27 +212,33 @@ class FairbatchDataset(FairnessDataset):
             Whether to use poisson sampling rather than batching.
         """
         # backend method; pylint: disable=too-many-arguments
+        args = (shuffle, replacement, poisson)
         # Fetch the target sub-dataset and its samples count.
         dataset = self.get_sensitive_group_subset(group)
         n_samples = self._counts[group]
-        # Adjust batch size when needed and set up a batches generator.
-        n_repeats, batch_size = divmod(batch_size, n_samples)
-        generator = self._generate_batches(
-            # fmt: off
-            dataset, group, nb_batches, batch_size,
-            shuffle, replacement, poisson,
-        )
+        # When the dataset is large enough, merely yield batches.
+        if batch_size <= n_samples:
+            yield from self._generate_batches(
+                dataset, group, nb_batches, batch_size, *args
+            )
         # When the batch size is larger than the number of data points,
         # make up a base batch with all points (duplicated if needed),
         # that will be combined with further batches of data.
-        if n_repeats:
-            full = self._get_full_dataset(dataset, n_samples, group)
-            full = self._concatenate_batches([full] * n_repeats)
-            for batch in generator:
-                yield self._concatenate_batches((full, batch))
-        # Otherwise, merely yield from the generator.
-        else:
-            yield from generator
+        else:
+            n_repeats, batch_size = divmod(batch_size, n_samples)
+            # Gather the full subset, optionally duplicated.
+            full = self._get_full_dataset(dataset, n_samples, group)
+            if n_repeats > 1:
+                full = self._concatenate_batches([full] * n_repeats)
+            # Add up further (batch-varying) samples (when needed).
+            if batch_size:
+                for batch in self._generate_batches(
+                    dataset, group, nb_batches, batch_size, *args
+                ):
+                    yield self._concatenate_batches([full, batch])
+            else:  # edge case: require exactly N times the full dataset
+                for _ in range(nb_batches):
+                    yield full
 
     def _generate_batches(
         self,
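
For readers outside the codebase, the following self-contained sketch mirrors the patched control flow, using NumPy and a hypothetical generate_group_batches function in place of the private declearn helpers (_get_full_dataset, _concatenate_batches, _generate_batches); it is an illustration under those assumptions, not the actual implementation:

import numpy as np

def generate_group_batches(data, nb_batches, batch_size, seed=0):
    # Illustrative analogue of the patched method, not declearn code.
    rng = np.random.default_rng(seed)
    n_samples = len(data)
    # Small batches: draw them directly from the subset.
    if batch_size <= n_samples:
        for _ in range(nb_batches):
            yield rng.choice(data, size=batch_size, replace=False)
        return
    # Oversized batches: build a base batch of duplicated full copies.
    n_repeats, batch_size = divmod(batch_size, n_samples)
    full = np.concatenate([data] * n_repeats)
    if batch_size:
        # Pad the base batch with varying remainder samples.
        for _ in range(nb_batches):
            extra = rng.choice(data, size=batch_size, replace=False)
            yield np.concatenate([full, extra])
    else:
        # Edge case fixed by this commit: the batch size is an exact
        # multiple of the sample count, so no remainder batch is drawn.
        for _ in range(nb_batches):
            yield full

# The previously-failing exact-multiple case now yields valid batches:
batches = list(generate_group_batches(np.arange(8), nb_batches=2, batch_size=16))
assert all(len(batch) == 16 for batch in batches)

Skipping the remainder branch when divmod leaves no rest is the essence of the fix: in that case the duplicated base batch already has exactly the requested size.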