Mini-batch training for scikit-learn

Closes #116 (closed).

This MR introduces significant changes in the training loop and data loading for scikit-learn. The ultimate goal was to enable mini-batch training for scikit-learn models, as suggested by @paandrey in #377 (closed). The primary design goal was to homogenize the training, testing and data loading functionalities between torch and scikit-learn training plans.

Please also review the corresponding documentation changes in fedbiomed.gitlabpages.inria.fr!60.

Below I summarize the most relevant changes.

NPDataLoader class to wrap numpy arrays in an interface similar to torch.DataLoader

The NPDataLoader class has been introduced, and the SKLearnDataManager has been modified accordingly. Similarly to the TorchDataManager, we now have the following call stack during training:

  1. Round.set_training_testing_data_loaders calls the training_data function defined by the researcher
  2. the training_data function returns a DataManager
  3. the Round calls DataManager.load to set the instance to a SKLearnDataManager
  4. then Round calls SKLearnDataManager.split, which returns two NPDataLoader objects (for training and testing data)
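The data flow above can be sketched as follows. This is a hypothetical, heavily simplified stand-in (the real Fed-BioMed classes have different constructors and many more options); it only illustrates how `SKLearnDataManager.split` hands back two `NPDataLoader` objects:

```python
# Simplified sketch of the Round -> DataManager -> NPDataLoader flow.
# Not the actual Fed-BioMed implementation; signatures are illustrative.
import numpy as np

class NPDataLoader:
    """Minimal stand-in: wraps two aligned numpy arrays."""
    def __init__(self, data, target):
        self.data, self.target = data, target

class SKLearnDataManager:
    def __init__(self, data, target):
        self._data, self._target = data, target

    def split(self, test_ratio=0.25):
        """Return (train_loader, test_loader) as NPDataLoader objects."""
        n_test = int(len(self._data) * test_ratio)
        return (NPDataLoader(self._data[n_test:], self._target[n_test:]),
                NPDataLoader(self._data[:n_test], self._target[:n_test]))

# In Fed-BioMed, the researcher-defined training_data() returns a DataManager
# that the Round resolves to a SKLearnDataManager; here we build it directly:
data, target = np.arange(20).reshape(10, 2), np.arange(10)
train_loader, test_loader = SKLearnDataManager(data, target).split(test_ratio=0.2)
print(len(train_loader.data), len(test_loader.data))  # 8 2
```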

Iterating over an `NPDataLoader` yields a `Tuple[np.ndarray, np.ndarray]`, so the recommended way to consume it is:

```python
for iter_idx, (data, target) in enumerate(training_dataloader):
    do_training(data, target)
```
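To make the iteration pattern concrete, here is a minimal, hypothetical `NPDataLoader` exhibiting the batched-iteration behaviour described above (the real class supports more options such as shuffling and drop-last semantics):

```python
# Illustrative minimal NPDataLoader; not the actual Fed-BioMed implementation.
import numpy as np

class NPDataLoader:
    def __init__(self, data: np.ndarray, target: np.ndarray, batch_size: int = 4):
        assert len(data) == len(target)
        self.data, self.target, self.batch_size = data, target, batch_size

    def __iter__(self):
        # Yield Tuple[np.ndarray, np.ndarray] mini-batches, torch.DataLoader-style
        for start in range(0, len(self.data), self.batch_size):
            yield (self.data[start:start + self.batch_size],
                   self.target[start:start + self.batch_size])

loader = NPDataLoader(np.arange(20).reshape(10, 2), np.arange(10), batch_size=4)
for iter_idx, (data, target) in enumerate(loader):
    print(iter_idx, data.shape, target.shape)
# yields batches of 4, 4, and 2 samples
```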

The class hierarchy is given below.

*(Figure: UML class hierarchy for SKLearnDataManager and NPDataLoader)*

New TrainingPlan class hierarchy

Some common functionalities have been moved to the BaseTrainingPlan class, and the hierarchy for scikit-learn classes has been improved. Below is the UML diagram.

*(Figure: UML diagram for training plan classes)*
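In outline, the refactoring puts framework-agnostic functionality in `BaseTrainingPlan` and specializes it per framework. The sketch below is illustrative only: the class names follow the MR, but the method names and bodies are placeholders, not the actual implementation:

```python
# Illustrative hierarchy sketch; method names/bodies are placeholders.
class BaseTrainingPlan:
    """Common functionality shared by torch and scikit-learn plans,
    e.g. holding the training and testing data loaders."""
    def set_data_loaders(self, train_loader, test_loader):
        self.training_data_loader = train_loader
        self.testing_data_loader = test_loader

class TorchTrainingPlan(BaseTrainingPlan):
    def training_routine(self):
        ...  # torch-specific mini-batch loop

class SKLearnTrainingPlan(BaseTrainingPlan):
    def training_routine(self):
        ...  # scikit-learn partial_fit loop over NPDataLoader batches

class SKLearnModelTrainingPlan(SKLearnTrainingPlan):
    ...  # model-specific subclass (hypothetical name)
```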

Exceptions

This MR introduces FedbiomedValueError and FedbiomedTypeError. There is a tradeoff: we lose some information about where the error originated, but gain more information about the type of error. To be discussed!
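One possible shape for these exceptions (a sketch, not the actual Fed-BioMed code) uses multiple inheritance, so errors stay catchable as the built-in `ValueError`/`TypeError` while also being identifiable as Fed-BioMed's own:

```python
# Sketch only; the actual Fed-BioMed exception hierarchy may differ.
class FedbiomedError(Exception):
    pass

class FedbiomedValueError(FedbiomedError, ValueError):
    pass

class FedbiomedTypeError(FedbiomedError, TypeError):
    pass

try:
    raise FedbiomedValueError("batch_size must be positive")
except ValueError as e:      # still catchable as a plain ValueError
    print(type(e).__name__)  # FedbiomedValueError
```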

Guidelines for MR review

General:

Specific to some cases:

  • update all conda envs consistently (development and VPN, Linux and macOS)
  • if the researcher was modified (e.g. new attributes in classes), check whether breakpoints need updating (breakpoint/load_breakpoint in Experiment(), save_state/load_state in aggregators, strategies, secagg, etc.)
Edited by CREMONESI Francesco
