Mini-batch training for scikit-learn
Closes #116 (closed).
This MR introduces significant changes in the training loop and data loading for scikit-learn. The ultimate goal was to enable mini-batch training for scikit-learn models, as suggested by @paandrey in #377 (closed). The primary design goal was to homogenize the training, testing and data loading functionalities between torch and scikit-learn training plans.
Please also review the corresponding documentation changes in fedbiomed.gitlabpages.inria.fr!60.
Below I summarize some of the most relevant changes.
## NPDataLoader

The `NPDataLoader` class, which wraps numpy arrays in an interface similar to `torch.DataLoader`, has been introduced, and `SKLearnDataManager` has been modified accordingly. Similarly to the `TorchDataManager`, we now have the following call stack during training:
- `Round.set_training_testing_data_loaders` calls the `training_data` function defined by the researcher
- the `training_data` function returns a `DataManager`
- the `Round` calls `DataManager.load` to set the instance to a `SKLearnDataManager`
- then `Round` calls `SKLearnDataManager.split`, which returns two `NPDataLoader` objects (for training and testing data)
Iterating over an `NPDataLoader` yields a `Tuple[np.ndarray, np.ndarray]`, so the recommended way to loop over batches is:

```python
for iter_idx, (data, target) in enumerate(training_dataloader):
    do_training(data, target)
```
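For illustration, here is a minimal sketch of a loader with this interface. The class name, constructor signature, and batching policy below are assumptions for the example, not the actual `NPDataLoader` implementation:

```python
import numpy as np
from typing import Iterator, Tuple


class NPDataLoaderSketch:
    """Illustrative wrapper yielding (data, target) mini-batches of numpy arrays."""

    def __init__(self, dataset: np.ndarray, target: np.ndarray, batch_size: int = 1):
        if len(dataset) != len(target):
            raise ValueError('dataset and target must have the same length')
        self._dataset = dataset
        self._target = target
        self._batch_size = batch_size

    def __iter__(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        # Yield consecutive slices; the last batch may be smaller than batch_size
        for start in range(0, len(self._dataset), self._batch_size):
            stop = start + self._batch_size
            yield self._dataset[start:stop], self._target[start:stop]

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch
        return -(-len(self._dataset) // self._batch_size)


# Usage mirroring the recommended loop above
X = np.arange(10).reshape(5, 2)
y = np.arange(5)
loader = NPDataLoaderSketch(X, y, batch_size=2)
for iter_idx, (data, target) in enumerate(loader):
    pass  # do_training(data, target)
```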
The class hierarchy is given below.
## New TrainingPlan class hierarchy

Some common functionalities have been moved to the `BaseTrainingPlan` class, and the hierarchy for scikit-learn classes has been improved. Below is the UML diagram.
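To make the idea concrete, here is a hypothetical sketch of such a hierarchy. The method names and attributes (`set_data_loaders`, `training_routine`, the loader attributes) are illustrative assumptions, not the actual Fed-BioMed API:

```python
from abc import ABC, abstractmethod


class BaseTrainingPlan(ABC):
    """Hypothetical base class holding functionality shared by torch and sklearn plans."""

    def set_data_loaders(self, train_loader, test_loader):
        # Common to both frameworks: store the loaders produced by the Round
        self.training_data_loader = train_loader
        self.testing_data_loader = test_loader

    @abstractmethod
    def training_routine(self):
        """Framework-specific mini-batch training loop."""


class SKLearnTrainingPlan(BaseTrainingPlan):
    def training_routine(self):
        for data, target in self.training_data_loader:
            pass  # e.g. model.partial_fit(data, target) on each mini-batch


class TorchTrainingPlan(BaseTrainingPlan):
    def training_routine(self):
        for data, target in self.training_data_loader:
            pass  # e.g. forward pass, backward pass, optimizer step
```

The point of the refactoring is that only `training_routine` differs per framework, while loader handling lives once in the base class.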
## Exceptions

This MR introduces `FedbiomedValueError` and `FedbiomedTypeError`. There is a tradeoff: we lose information about where the error originated, but we gain more information about the type of error. To be discussed!
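One plausible shape for these classes is multiple inheritance from a common framework base and the corresponding builtin; the `FedbiomedError` base name is an assumption here. This keeps both a framework-wide `except FedbiomedError` handler and builtin `except ValueError`/`except TypeError` handlers working:

```python
class FedbiomedError(Exception):
    """Assumed common base for Fed-BioMed exceptions."""


class FedbiomedValueError(FedbiomedError, ValueError):
    """Raised for invalid values; also catchable as a plain ValueError."""


class FedbiomedTypeError(FedbiomedError, TypeError):
    """Raised for invalid types; also catchable as a plain TypeError."""


# A caller can catch either the framework-wide base or the builtin category
try:
    raise FedbiomedTypeError('expected np.ndarray')
except TypeError as exc:
    caught = exc
```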
## Guidelines for MR review

General:

- give a glance to DoD
- check coding rules and coding style
- check docstrings (e.g. run `tests/docstrings/check_docstrings`)

Specific to some cases:

- update all conda envs consistently (`development` and `vpn`, Linux and MacOS)
- if the researcher was modified (e.g. new attributes in classes), check whether breakpoints need an update (`breakpoint`/`load_breakpoint` in `Experiment()`, `save_state`/`load_state` in aggregators, strategies, secagg, etc.)