Mentions légales du service

Skip to content

Enhance `declearn.dataset`: implement `TensorflowDataset` and improve `TorchDataset` and `InMemoryDataset`.

ANDREY Paul requested to merge enhance-datasets into develop

This MR is a continuation of the !47 (merged) one, which introduced the declearn.dataset.torch.TorchDataset interface, together with some modifications to declearn.dataset.Dataset and unit tests for its current subclasses.

The aim of this MR was three-fold:

  • Implement the replacement parameter for InMemoryDataset.generate_batches.
    • This parameter was introduced as part of !47 (merged) and implemented for TorchDataset, but left unused in InMemoryDataset.
  • Implement the declearn.dataset.tensorflow.TensorflowDataset counterpart to TorchDataset.
    • This has been a long-due goal, and tackling it will close #21 (closed)
  • Enable wrapping Torch and TensorFlow datasets that yield variable-size samples.
    • The initial implementation of TorchDataset works fine with fixed-size inputs, but does not support padding things on the go.
    • TensorFlow provides with a variety of tools to either pad or stack as ragged tensors variable-size inputs.
    • The objective was to tackle that requirement for both frameworks, with new and similar unit test cases to assert it.

As such, this MR does the following:

  • Implement TensorflowDataset interface to tf.data.Dataset.
  • Implement replacement parameter use for InMemoryDataset.generate_batches.
  • Implement support for TensorFlow and Torch datasets with variable-size sequences of tokens.
    • Implement support for padded and ragged batching in TensorflowDataset.
    • Implement support for custom collate functions in TorchDataset.
    • Implement and expose the collate_with_padding custom collate for use with TorchDataset
  • Add some unit tests for existing and new Dataset.generate_batches features.
    • Add unit tests for batch-generation with replacement.
    • Add unit tests for datasets with variable-size sequences of tokens.
  • Bonus tasks:
    • Implement a custom PoissonSampler for torch, instead of using the opacus one.

Closes #21 (closed)

Merge request reports

Loading