Num updates
Closes #369, #377, and #435.
This MR introduces a new training argument: `num_updates`. It is equivalent to the number of iterations in the training loop. For mini-batch SGD, this is equivalent to the number of model/optimizer updates, hence the name.
For backward compatibility we still support using `epochs`; however, we encourage using `num_updates` whenever possible. For this reason I have changed some notebooks, but not all.
Note that `num_updates` yields the same number of iterations on every node, while `epochs` may yield a different number of iterations per node.
From a user's point of view, I highly recommend reviewing the associated documentation MR fedbiomed.gitlabpages.inria.fr!66.
## Implementation details
A history of the (difficult) implementation discussion can be found here and in MR !158 (closed).
### New class: `MiniBatchTrainingIterationsAccountant`

This class has the following responsibilities:
- manage iterators for epochs and batches
- provide up-to-date values for reporting
- handle the different semantics depending on whether the researcher asked for `num_updates` or `epochs`
The idea behind this class is that the training loop within a training plan can always be written as follows:

```python
iterations_accountant = MiniBatchTrainingIterationsAccountant(self)  # self is the training plan instance
for epoch in iterations_accountant.iterate_epochs():
    training_data_iterator = iter(self.training_data_loader)  # we need this to access the actual data
    for batch in iterations_accountant.iterate_batches():
        data, target = next(training_data_iterator)  # retrieve the actual data
        iterations_accountant.increment_sample_counters(batch_size)  # update internal reporting attributes
        # ... do the training
        if iterations_accountant.should_log_this_batch():
            # retrieve reporting information
            num_samples, num_samples_max = iterations_accountant.reporting_on_num_samples()
            num_iter, num_iter_max = iterations_accountant.reporting_on_num_iter()
            epoch_to_report = iterations_accountant.reporting_on_epoch()
            # do the reporting
```
The loop above illustrates both the use of this class to manage the iterations, as well as its use to simplify reporting.
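To make the contract of the two iterators concrete, here is a heavily simplified, hypothetical stand-in for the accountant. The class name `ToyIterationsAccountant` and its constructor taking the precomputed triple directly are my own illustrative assumptions; the real class derives these values from the training plan. It is a sketch of the iteration pattern, not the actual implementation:

```python
class ToyIterationsAccountant:
    """Illustrative sketch only: a minimal accountant driven by the
    (epochs, batches-in-last-epoch, batches-per-epoch) triple."""

    def __init__(self, num_epochs, num_batches_in_last_epoch, num_batches_per_epoch):
        self._num_epochs = num_epochs
        self._num_batches_in_last_epoch = num_batches_in_last_epoch
        self._num_batches_per_epoch = num_batches_per_epoch
        self._cur_epoch = 0

    def iterate_epochs(self):
        # The extra (possibly empty) last epoch is already included in num_epochs
        for epoch in range(1, self._num_epochs + 1):
            self._cur_epoch = epoch
            yield epoch

    def iterate_batches(self):
        if self._cur_epoch == self._num_epochs:
            n = self._num_batches_in_last_epoch  # last epoch may be shorter, or empty
        else:
            n = self._num_batches_per_epoch
        for batch in range(1, n + 1):
            yield batch


# num_updates=7 with a data loader of length 4 -> (2 epochs, 3 batches in last, 4 per epoch)
accountant = ToyIterationsAccountant(2, 3, 4)
total_iterations = sum(
    1 for _ in accountant.iterate_epochs() for _ in accountant.iterate_batches()
)
print(total_iterations)  # 7
```

Nesting the two generators this way yields exactly the requested number of updates, with the partial last epoch handled transparently by the accountant.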
### Handling both `num_updates` and `epochs`

To handle both `num_updates` and `epochs`, the implementation relies on a simple idea: converting `num_updates` into epochs on each node. Therefore the training loop inside the training plan is always implemented in terms of epochs and batches.
The logic of this conversion can be found in `MiniBatchTrainingIterationsAccountant._n_training_iterations`, a function whose purpose is to provide all the information needed to manage the number of iterations in the training loop. Its signature is:

```python
def _n_training_iterations(self):
```
The function does not return anything; instead, it updates three internal attributes:

- number of epochs
- number of batches in the last epoch, i.e. an additional number of iterations to be performed in the last epoch if `num_updates` was defined
- number of batches per epoch

Note: we always perform one additional epoch, which may be empty. This is done to simplify the implementation.
Here are some examples. Consider a data loader with a length of 4. Then:

| training args | epochs | batches in last epoch | batches per epoch |
|---|---|---|---|
| `{'epochs': 1}` | 2 | 0 | 4 |
| `{'epochs': 1, 'batch_maxnum': 2}` | 2 | 0 | 2 |
| `{'num_updates': 7}` | 2 | 3 | 4 |
| `{'num_updates': 3}` | 1 | 3 | 4 |
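The conversion can be sketched as a standalone function that reproduces the table above. This is a hypothetical illustration under my own assumptions (a free function taking explicit arguments); the real logic lives in `MiniBatchTrainingIterationsAccountant._n_training_iterations` and reads these values from the training plan:

```python
def n_training_iterations(num_batches_per_epoch, epochs=None, num_updates=None, batch_maxnum=None):
    """Illustrative sketch: convert training args into the triple
    (number of epochs, batches in last epoch, batches per epoch)."""
    if num_updates is not None:
        # num_updates takes precedence; batch_maxnum is ignored.
        # The +1 accounts for the extra (possibly empty) last epoch.
        n_epochs = num_updates // num_batches_per_epoch + 1
        n_last = num_updates % num_batches_per_epoch
        return n_epochs, n_last, num_batches_per_epoch
    if batch_maxnum:
        # Legacy argument: cap the number of batches per epoch.
        num_batches_per_epoch = batch_maxnum
    # Extra empty last epoch: epochs + 1, with 0 batches in the last one.
    return epochs + 1, 0, num_batches_per_epoch


# Reproduce the table above (data loader of length 4):
print(n_training_iterations(4, epochs=1))                  # (2, 0, 4)
print(n_training_iterations(4, epochs=1, batch_maxnum=2))  # (2, 0, 2)
print(n_training_iterations(4, num_updates=7))             # (2, 3, 4)
print(n_training_iterations(4, num_updates=3))             # (1, 3, 4)
```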
## Minor changes in `TrainingArgs`

We have decided that the training arguments `epochs`, `batch_maxnum`, and `num_updates` should have a default value of `None`.
This ensures two things:

- we respect the original idea of `TrainingArgs`: assign a default value to every training argument that a researcher might specify
- we can keep track of whether the researcher specified a value for `epochs` or for `num_updates`
For more details, see this Discord thread.
### `batch_maxnum`

If the researcher specified `epochs`, they may also specify the legacy training argument `batch_maxnum`, whose meaning is "number of batches per epoch". For example, `epochs=2` and `batch_maxnum=3` leads to exactly 6 iterations, regardless of the dataset size.

If the researcher specified `num_updates`, then `batch_maxnum` is ignored.
### Precedence rules

If the user specifies both `epochs` and `num_updates`:

- `num_updates` takes precedence
- `epochs` (and `batch_maxnum`) are ignored

For example, `epochs=2, batch_maxnum=3, num_updates=17` leads to exactly 17 iterations, because `num_updates` takes precedence.
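The precedence rules can be summarized with a tiny helper that computes the total number of iterations a node will run. This is a hypothetical sketch for illustration only, not a function from the Fed-BioMed codebase:

```python
def total_iterations(loader_len, epochs=None, num_updates=None, batch_maxnum=None):
    """Illustrative sketch of the precedence rules."""
    if num_updates is not None:
        # num_updates takes precedence: epochs and batch_maxnum are ignored.
        return num_updates
    # batch_maxnum caps the number of batches per epoch, if given.
    batches_per_epoch = batch_maxnum if batch_maxnum else loader_len
    return epochs * batches_per_epoch


print(total_iterations(4, epochs=2, batch_maxnum=3, num_updates=17))  # 17
print(total_iterations(4, epochs=2, batch_maxnum=3))                  # 6
```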
## Validation

The `MiniBatchTrainingIterationsAccountant` class is not currently used for the validation loop (although it could be, with minor additions). This is because the validation loop has no complicated logic to handle (e.g. `num_updates` does not apply). Most of the time, the validation loop is literally a single epoch with a single batch containing the whole validation dataset.
However, this MR also fixes a reporting bug in validation, since I was working tangentially on that as well.
## Opacus

The current implementation reports the exact number of samples observed during training to both the researcher and the node. There is an inconsistency in the total reported number of samples: with Opacus, under certain conditions (for example, when `batch_maxnum` or `num_updates` is used) we cannot know in advance the exact number of samples that will be observed. The solution adopted in this implementation is to:

- always report the "predicted" total number of samples, except...
- ...at the last iteration, where we update the total with the real number of samples observed
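As a sketch, the reporting rule could look like the following. The helper name and signature are hypothetical, chosen only to illustrate the two bullet points above:

```python
def reported_total_samples(predicted_total, observed_so_far, is_last_iteration):
    """Illustrative sketch: report the predicted total until the last
    iteration, where the actually observed count replaces it."""
    return observed_so_far if is_last_iteration else predicted_total


# With Opacus the actual count may drift from the prediction:
print(reported_total_samples(100, 37, False))  # 100 (predicted)
print(reported_total_samples(100, 96, True))   # 96 (real total, reported at the end)
```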
## Developer Certificate Of Origin (DCO)

By opening this merge request, you agree to the Developer Certificate of Origin (DCO).
This DCO essentially means that:
- you offer the changes under the same license agreement as the project,
- you have the right to do that, and
- you did not steal somebody else's work.
## License

Project code files should begin with these comment lines to help trace their origin:

```python
# This file is originally part of Fed-BioMed
# SPDX-License-Identifier: Apache-2.0
```

Code files can be reused from another project with a compatible non-contaminating license. They shall retain the original license and copyright mentions. The `CREDIT.md` file and the `credit/` directory shall be completed and updated accordingly.
## Guidelines for MR review

General:

- give a glance to the DoD
- check coding rules and coding style
- check docstrings (e.g. run `tests/docstrings/check_docstrings`)

Specific to some cases:

- update all conda envs consistently (`development` and `vpn`, Linux and MacOS)
- if the researcher side was modified (e.g. new attributes in classes), check whether breakpoints need an update (`breakpoint`/`load_breakpoint` in `Experiment()`, `save_state`/`load_state` in aggregators, strategies, secagg, etc.)