Enhance `declearn.dataset`: implement `TensorflowDataset` and improve `TorchDataset` and `InMemoryDataset`.
This MR is a continuation of !47 (merged), which introduced the `declearn.dataset.torch.TorchDataset` interface, together with some modifications to `declearn.dataset.Dataset` and unit tests for its current subclasses.
The aim of this MR was three-fold:
- Implement the `replacement` parameter for `InMemoryDataset.generate_batches`.
  - This parameter was introduced as part of !47 (merged) and implemented for `TorchDataset`, but left unused in `InMemoryDataset`.
- Implement the `declearn.dataset.tensorflow.TensorflowDataset` counterpart to `TorchDataset`.
  - This has been a long-due goal, and tackling it will close #21 (closed).
- Enable wrapping Torch and TensorFlow datasets that yield variable-size samples.
  - The initial implementation of `TorchDataset` works fine with fixed-size inputs, but does not support padding samples on the fly.
  - TensorFlow provides a variety of tools to either pad variable-size inputs or stack them as ragged tensors.
  - The objective was to tackle that requirement for both frameworks, with new and similar unit test cases to assert it.
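To illustrate the variable-size requirement, here is a minimal, framework-agnostic sketch of padding a batch of token sequences to a common length. It is not the code from this MR: `pad_batch` and `pad_value` are hypothetical names, and plain Python lists stand in for tensors. A Torch version would typically build tensors instead (for example with `torch.nn.utils.rnn.pad_sequence`), while `tf.data.Dataset.padded_batch` covers the TensorFlow side.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    samples: Sequence[Sequence[int]],
    pad_value: int = 0,
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-size token sequences to the longest one.

    Hypothetical helper, for illustration only. Returns the padded
    batch and the original length of each sequence, which can be
    used to mask out the padding later on.
    """
    lengths = [len(sample) for sample in samples]
    max_len = max(lengths, default=0)  # an empty batch has length 0
    padded = [
        list(sample) + [pad_value] * (max_len - len(sample))
        for sample in samples
    ]
    return padded, lengths


batch, lengths = pad_batch([[5, 2, 9], [7], [3, 4]])
# batch   -> [[5, 2, 9], [7, 0, 0], [3, 4, 0]]
# lengths -> [3, 1, 2]
```

Returning the original lengths alongside the padded batch is one possible design choice; a ragged representation would avoid padding altogether at the cost of non-rectangular batches.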
As such, this MR does the following:
- Implement `TensorflowDataset` interface to `tf.data.Dataset`.
- Implement `replacement` parameter use for `InMemoryDataset.generate_batches`.
- Implement support for TensorFlow and Torch datasets with variable-size sequences of tokens.
  - Implement support for padded and ragged batching in `TensorflowDataset`.
  - Implement support for custom collate functions in `TorchDataset`.
  - Implement and expose the `collate_with_padding` custom collate for use with `TorchDataset`.
- Add some unit tests for existing and new `Dataset.generate_batches` features.
  - Add unit tests for batch-generation with replacement.
  - Add unit tests for datasets with variable-size sequences of tokens.
- Bonus tasks:
  - Implement a custom `PoissonSampler` for torch, instead of using the `opacus` one.
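For readers unfamiliar with the two sampling schemes mentioned in this list, the sketch below contrasts fixed-size batching with replacement and Poisson sampling on sample indices. It is an assumption-laden illustration in plain Python, not the `PoissonSampler` or `generate_batches` code from this MR: the function names, the `seed` argument and the choice of `n_samples // batch_size` batches per epoch are all hypothetical.

```python
import random
from typing import Iterator, List, Optional


def batches_with_replacement(
    n_samples: int,
    batch_size: int,
    n_batches: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield fixed-size index batches, drawing indices with replacement.

    Hypothetical helper: an index may appear several times, both
    within a batch and across batches.
    """
    rng = random.Random(seed)
    for _ in range(n_batches):
        yield [rng.randrange(n_samples) for _ in range(batch_size)]


def poisson_batches(
    n_samples: int,
    batch_size: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield index batches drawn with Poisson sampling.

    Hypothetical helper: for each batch, every index is included
    independently with probability batch_size / n_samples, so the
    actual batch size varies around `batch_size` (and may be zero).
    """
    if not 0 < batch_size <= n_samples:
        raise ValueError("batch_size must be in [1, n_samples].")
    rng = random.Random(seed)
    rate = batch_size / n_samples
    # Assumed convention: about one pass over the data per epoch.
    for _ in range(max(1, n_samples // batch_size)):
        yield [idx for idx in range(n_samples) if rng.random() < rate]
```

With Poisson sampling, `batch_size` is only the expected batch size, so downstream code has to cope with variable-size (possibly empty) batches.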
Closes #21 (closed)