This MR introduces a new training argument: `num_updates`. It is the number of iterations of the training loop. For mini-batch SGD, this equals the number of model/optimizer updates, hence the name.
For backward compatibility we still support using `epochs`, but we encourage using `num_updates` whenever possible. For this reason I have changed some notebooks, but not all.
From a user point of view, I highly recommend reviewing the associated documentation MR fedbiomed.gitlabpages.inria.fr!66.
DataLoaders now have significantly different semantics. In particular, both the torch and numpy dataloaders iterate infinitely over the dataset, always building batches of the same size. This differs from the previous behaviour, where dataloaders raised `StopIteration` at the end of one epoch, and the last batch could be smaller than the specified batch size.
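To make the new semantics concrete, here is a minimal sketch of an infinitely cycling, fixed-batch-size iterator. The function name and signature are hypothetical, not the actual Fed-BioMed DataLoader API:

```python
import itertools

def infinite_batches(dataset, batch_size):
    """Yield fixed-size batches forever, wrapping around the dataset.

    `dataset` is any indexable sequence. Hypothetical helper: this
    only illustrates the new iteration semantics, it is not the
    actual Fed-BioMed implementation.
    """
    indices = itertools.cycle(range(len(dataset)))
    while True:
        # The batch is always exactly `batch_size` long; when the end
        # of the dataset is reached, we wrap around instead of yielding
        # a short final batch or raising StopIteration.
        yield [dataset[next(indices)] for _ in range(batch_size)]

data = [10, 11, 12, 13, 14]
gen = infinite_batches(data, batch_size=2)
print(next(gen))  # [10, 11]
print(next(gen))  # [12, 13]
print(next(gen))  # [14, 10]  <- wraps around instead of a short batch
```

Note that with this behaviour an "epoch" is no longer a natural stopping point, which is why `num_updates` becomes the primary way to bound training.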
You may find more information on the motivation and implementation details here.
Furthermore, I have factored out the DataLoader-related implementation (for both torch and sklearn) into a separate module.
We have decided that the `epochs` and `num_updates` training arguments should have a default value of `None`. This ensures two things:
- we respect the original idea of `TrainingArgs`, which assigns a default value to every training argument a researcher might specify
- we can keep track of whether the researcher specified a value for `epochs` or for `num_updates`
For more details, see this Discord thread.
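The second point can be sketched as follows. With a `None` default, "not specified" and "specified" are distinguishable, which a concrete default like `0` or `1` would not allow. The helper name and dict shape are hypothetical, for illustration only:

```python
def was_user_specified(training_args: dict, key: str) -> bool:
    """Return True if the researcher explicitly set this argument.

    Hypothetical helper: with a default of None, any non-None value
    must have come from the researcher.
    """
    return training_args.get(key) is not None

args = {"epochs": None, "num_updates": 100}
print(was_user_specified(args, "epochs"))       # False
print(was_user_specified(args, "num_updates"))  # True
```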
For both torch and scikit-learn, the training loop is no longer a nested loop (first over epochs, then over batches), but a single loop over the number of updates. This is explained in more detail here.
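A minimal sketch of the flat loop, assuming an infinite batch iterator as described above (names are hypothetical, not the actual Fed-BioMed training routine):

```python
import itertools

def train(step_fn, batches, num_updates):
    """Run exactly `num_updates` model/optimizer updates.

    Hypothetical sketch: a single flat loop replaces the old nested
    epoch/batch loops. Epoch boundaries are crossed transparently by
    the infinite `batches` iterator.
    """
    for _ in range(num_updates):
        step_fn(next(batches))

# Usage: 7 updates over a 3-sample dataset; no epoch bookkeeping needed.
seen = []
batches = itertools.cycle([[0], [1], [2]])
train(seen.append, batches, num_updates=7)
print(len(seen))  # 7
```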
If the researcher specifies `epochs`, they may also specify the legacy training argument `batch_maxnum`, which means "number of batches per epoch". For example, `epochs=2` and `batch_maxnum=3` lead to exactly 6 iterations, regardless of the dataset size.
If the user specifies both `epochs` and `num_updates`:
- `num_updates` takes precedence
- `epochs` is ignored
For example, `epochs=2, batch_maxnum=3, num_updates=17` leads to exactly 17 iterations, because `num_updates` takes precedence.
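The precedence rules above can be summarised in one small resolver. This is a hypothetical helper for illustration; the actual implementation lives elsewhere in the codebase:

```python
def resolve_total_iterations(num_updates=None, epochs=None,
                             batch_maxnum=None, batches_per_epoch=None):
    """Resolve the total number of training iterations.

    Hypothetical sketch of the precedence rules:
    1. `num_updates`, if given, always wins.
    2. Otherwise `epochs` is used, capped per-epoch by the legacy
       `batch_maxnum` if it was specified.
    """
    if num_updates is not None:
        return num_updates  # epochs and batch_maxnum are ignored
    if epochs is not None:
        if batch_maxnum is not None:
            return epochs * batch_maxnum
        return epochs * batches_per_epoch
    raise ValueError("either num_updates or epochs must be specified")

print(resolve_total_iterations(epochs=2, batch_maxnum=3))                  # 6
print(resolve_total_iterations(epochs=2, batch_maxnum=3, num_updates=17))  # 17
```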
Guidelines for MR review
- take a glance at the DoD
- check coding rules and coding style
- check docstrings (e.g. run the docstring checks)
Specific to some cases:
- update all conda envs consistently (vpn, Linux and MacOS)
- if the researcher was modified (e.g. new attributes in classes), check whether breakpoints need updating (`load_state` in aggregators, strategies, secagg, etc.)