# Clarifying our support for multimodal data
We define as "multimodal" any data that are not represented by a single tensor, but rather by (potentially nested) collections of tensors.
For example, `MedicalFolderDataset` is twice multimodal: once because it is a tuple of (images, tabular demographics), and again because the images themselves are a dictionary of `{'imaging modality': image tensor}`.
A lot of healthcare data are expected to be multimodal in this sense. In Fed-BioMed, we want to make it as simple as possible for researchers and nodes to work with multimodal data. Therefore, we wish to have good support for such data.
Multimodal data are one of the reasons why, for example, we allow researchers to define custom `training_data` and `training_step` functions.
## The problem
Supporting such arbitrary collections of data leads to very ugly code and creates potential for bugs. A classic example is discovering the shape of a particular dataset: we need several if/else statements to handle tensors, arrays, dictionaries, etc., as sketched below.
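For illustration, the kind of dispatch this forces on us looks roughly like the following (a simplified sketch, not the actual Fed-BioMed code; the helper name is made up):

```python
import torch

def batch_size_of(data):
    """Hypothetical helper: the case-by-case dispatch that nested batches force on us."""
    if isinstance(data, torch.Tensor):
        # plain mono-modal batch, e.g. MNIST
        return data.shape[0]
    if isinstance(data, dict):
        # {'modality name': tensor} batch: peek at an arbitrary modality
        return next(iter(data.values())).shape[0]
    if isinstance(data, (list, tuple)):
        # (inputs, demographics)-style batch: look at the first element
        return batch_size_of(data[0])
    raise TypeError(f"Cannot infer the batch size of a {type(data)}")
```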
Below is a (tentative) list of complications that are related to multimodal data.
- obtaining the batch size in `TrainingPlan.training_routine`
- obtaining the number of samples trained in `Strategy.refine`
- sending the data to the device in `TorchTrainingPlan`
- feeding the data to the ML model
## The current situation
Currently we have taken two strategies:

- The researcher's responsibility approach: whenever the researcher can provide custom code to handle multimodal data, they should do it. We use this for example in `training_step`, and also in the customized `MedicalFolderStrategy` that we currently provide in the notebooks.
- The recursive traversal approach: we can recursively traverse these arbitrary nested collections of tensors until we reach a leaf (i.e. a tensor) and apply our logic there. We do this in the `send_to_device` function, as well as in `infer_batch_size` (see the sketch after this list).
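The pattern behind the second approach is roughly the following (a simplified sketch of the recursive traversal, not the exact `send_to_device` implementation):

```python
import torch

def send_to_device(data, device):
    """Simplified sketch: recurse through nested containers until a tensor
    leaf is reached, then apply the per-tensor logic there."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: send_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(send_to_device(item, device) for item in data)
    raise TypeError(f"Unsupported element type: {type(data)}")
```

The same traversal has to be re-implemented, with slightly different leaf logic, for every new operation, which is where the ad-hoc tweaks tend to creep in.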
The first approach should be maintained and used whenever possible, with the goal of maximizing the researcher's interactivity and flexibility. The second approach is often complicated and clunky, and almost always ends up requiring some ad-hoc tweak to the solution. In evaluating new solutions, we should focus on the situations where the second approach is particularly weak, such as the batch size and number-of-samples examples above.
## Concrete example of the problem
The `DataLoader` for a `MedicalFolderDataset` returns a tuple whose first element holds the modalities as a dict; each entry of that dict contains the batch of images for one modality. The batch length therefore has to be retrieved by accessing a single modality, whereas for mono-modality datasets such as MNIST it can be read directly from the batch tensor. As a consequence, every place that needs the batch size must first check whether the batch is a dict or not.
Here is the current data structure for each batch:

```
(
    [
        {
            "Modality_1": batch of images as tensor(...),
            "Modality_2": batch of images as tensor(...),
        },
        demographics,
    ],
    target,
)
```
The researcher selects a modality or the demographics in the training step as shown below:

```python
data[0]["Modality_1"]  # -> imaging modality
data[1]                # -> demographics
```
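For instance, a custom `training_step` consuming this structure might look like the following (an illustrative sketch only; the model, loss, and exact signature are placeholders):

```python
import torch

def training_step(self, data, target):
    # data = [ {modality name: image batch}, demographics ], as described above
    modalities, demographics = data
    images = modalities["Modality_1"]      # the researcher picks a modality explicitly
    output = self.model().forward(images)  # hypothetical model access
    return torch.nn.functional.mse_loss(output, target)
```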
## Proposed solutions
### We only support tensors
Of course, a simple solution would be to ask nodes to convert all of their multimodal data into tensors.
This solution is really sub-optimal, however: in Fed-BioMed we want to minimize the amount of effort required to upload datasets, and hence we wish to support multimodal data in a format that is as similar as possible to the original raw format.
### Third-party library
Both xarray and muData seem quite interesting and appear to be widely used.
However, third-party libraries bring additional dependencies, and potentially conversion issues (both bugs and performance overhead) when used with PyTorch.
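To make the option concrete, a multimodal batch could for instance be held in an `xarray.Dataset` with named dimensions (a sketch with made-up names and shapes), at the cost of an explicit conversion back to tensors for PyTorch:

```python
import numpy as np
import torch
import xarray as xr

# Hypothetical multimodal batch: two imaging modalities plus demographics,
# all sharing a named "sample" dimension.
batch = xr.Dataset(
    {
        "Modality_1": (("sample", "x", "y"), np.random.rand(8, 64, 64)),
        "Modality_2": (("sample", "x", "y"), np.random.rand(8, 64, 64)),
        "demographics": (("sample", "feature"), np.random.rand(8, 5)),
    }
)

batch_size = batch.sizes["sample"]                     # uniform way to read the batch size
images = torch.from_numpy(batch["Modality_1"].values)  # the conversion step PyTorch needs
```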
### `Batch` class

Implementing a `Batch` class. See the comment below, and the photo summarizing our whiteboard discussion.
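As a rough illustration of the idea (purely hypothetical; the actual interface is what the comment and whiteboard photo discuss), such a class could wrap the nested collection once and expose the operations listed in the problem section:

```python
import torch

class Batch:
    """Hypothetical wrapper around a (possibly nested) collection of tensors."""

    def __init__(self, data):
        self._data = data

    def _leaves(self, data):
        # Recursively yield every tensor leaf of the nested structure.
        if isinstance(data, torch.Tensor):
            yield data
        elif isinstance(data, dict):
            for value in data.values():
                yield from self._leaves(value)
        elif isinstance(data, (list, tuple)):
            for item in data:
                yield from self._leaves(item)

    def batch_size(self):
        # Assume every leaf shares the same leading (batch) dimension.
        return next(self._leaves(self._data)).shape[0]

    def to(self, device):
        # Return a new Batch with every tensor leaf moved to the given device.
        def move(data):
            if isinstance(data, torch.Tensor):
                return data.to(device)
            if isinstance(data, dict):
                return {key: move(value) for key, value in data.items()}
            if isinstance(data, (list, tuple)):
                return type(data)(move(item) for item in data)
            return data
        return Batch(move(self._data))

    def __getitem__(self, key):
        # Keep the researcher-facing access pattern, e.g. batch[0]["Modality_1"].
        return self._data[key]
```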