Draft: num updates
Closes #369 and #377.
This MR introduces a new training argument: `num_updates`. It is the number of iterations of the training loop; for mini-batch SGD, this is equivalent to the number of model/optimizer updates, hence the name.
For backward compatibility we still support `epochs`; however, we encourage using `num_updates` whenever possible. For this reason I have updated some notebooks, but not all.
From a user's point of view, I highly recommend reviewing the associated documentation MR fedbiomed.gitlabpages.inria.fr!66.
## Implementation details
### Major changes in DataLoaders
DataLoaders now have significantly different semantics. In particular, both the torch and numpy dataloaders iterate infinitely over the dataset, building batches that are always of the same size. This differs from the previous behaviour, where dataloaders raised `StopIteration` at the end of an epoch, and the last batch could be smaller than the specified batch size.
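To illustrate the new semantics, here is a simplified sketch of an infinitely iterating, fixed-batch-size loader. This is not the actual fedbiomed implementation; the function name and carry-over policy for leftover samples are hypothetical:

```python
import numpy as np

def infinite_batches(dataset, batch_size, rng=None):
    """Yield fixed-size batches forever, reshuffling at each pass over the data."""
    rng = rng or np.random.default_rng(0)
    indices = np.arange(len(dataset))
    buffer = []
    while True:
        rng.shuffle(indices)
        for i in indices:
            buffer.append(dataset[i])
            if len(buffer) == batch_size:
                yield np.stack(buffer)
                buffer = []  # leftover samples simply carry over into the next pass

data = np.arange(10).reshape(5, 2)   # 5 samples of 2 features
loader = infinite_batches(data, batch_size=3)
batches = [next(loader) for _ in range(4)]
assert all(b.shape == (3, 2) for b in batches)  # every batch is full-size, never StopIteration
```

Note that with 5 samples and `batch_size=3`, the loader simply keeps cycling: there is no partial final batch and no end-of-epoch exception.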
You may find more information on the motivation and implementation details here.
Furthermore, I have factored out the DataLoader-related implementation (for both torch and sklearn) into a separate `fedbiomed.common.data.loaders` module.
### Minor changes in TrainingArgs
We have decided that the following training arguments should have a default value of `None`. This ensures two things:
- we respect the original idea of `TrainingArgs`: assign a default value to every training argument a researcher might specify
- we can keep track of whether the researcher specified a value for `epochs` or for `num_updates`
For more details, see this Discord thread.
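The `None`-default idea can be illustrated as follows. This is a hypothetical, heavily simplified version of the logic, not the actual `TrainingArgs` class:

```python
# All training arguments default to None; a remaining None after merging
# means "not specified by the researcher".
DEFAULTS = {'epochs': None, 'num_updates': None, 'batch_maxnum': None}

def resolve(user_args):
    """Merge researcher-supplied arguments over the defaults."""
    unknown = set(user_args) - set(DEFAULTS)
    if unknown:
        raise ValueError(f'unknown training arguments: {unknown}')
    return {**DEFAULTS, **user_args}

args = resolve({'epochs': 2})
assert args['epochs'] == 2
assert args['num_updates'] is None  # so we can tell the researcher chose epochs
```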
### Minor changes in the training loop (TrainingPlan)
For both torch and scikit-learn, the training loop is no longer a nested loop (first over epochs, then over batches), but a single loop over the number of updates. This is explained more in detail here.
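The flattened loop can be sketched roughly like this (illustrative only; the real TrainingPlan code differs, and `step` is a hypothetical stand-in for one model/optimizer update):

```python
import itertools

def train(step, loader, num_updates):
    """Run exactly num_updates iterations over an infinite loader:
    a single flat loop, with no nested epoch/batch loops."""
    for count, batch in enumerate(loader, start=1):
        step(batch)               # one model/optimizer update per batch
        if count >= num_updates:  # stop exactly at num_updates iterations
            return count

# With an infinite source, training still terminates after num_updates steps:
assert train(lambda batch: None, itertools.count(), num_updates=17) == 17
```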
### `batch_maxnum`
If the researcher specified `epochs`, they are also allowed to specify the legacy training argument `batch_maxnum`, which means "number of batches per epoch". For example, `epochs=2` and `batch_maxnum=3` lead to exactly 6 iterations (2 × 3), regardless of the dataset size.
### Precedence rules
If the user specifies both `epochs` and `num_updates`:
- `num_updates` takes precedence
- `epochs` (and `batch_maxnum`) will be ignored
For example, `epochs=2, batch_maxnum=3, num_updates=17` leads to exactly 17 iterations, because `num_updates` takes precedence.
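The precedence rules boil down to something like the helper below. This is an illustrative sketch, not actual fedbiomed code; `num_batches_per_epoch` is assumed to come from the dataset size:

```python
def total_iterations(epochs=None, num_updates=None, batch_maxnum=None,
                     num_batches_per_epoch=None):
    """Apply the precedence rules: num_updates wins over epochs/batch_maxnum."""
    if num_updates is not None:
        return num_updates        # epochs and batch_maxnum are ignored
    per_epoch = batch_maxnum if batch_maxnum is not None else num_batches_per_epoch
    return epochs * per_epoch

assert total_iterations(epochs=2, batch_maxnum=3) == 6                   # legacy behaviour
assert total_iterations(epochs=2, batch_maxnum=3, num_updates=17) == 17  # num_updates wins
```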
## Guidelines for MR review
General:
- give a glance to DoD
- check coding rules and coding style
- check docstrings (e.g. run `tests/docstrings/check_docstrings`)
Specific to some cases:
- update all conda envs consistently (`development` and `vpn`, Linux and MacOS)
- if researcher was modified (e.g. new attributes in classes), check if breakpoints need an update (`breakpoint`/`load_breakpoint` in `Experiment()`, `save_state`/`load_state` in aggregators, strategies, secagg, etc.)