Mentions légales du service

Skip to content

Num updates

CREMONESI Francesco requested to merge feature/369-num-updates-final into develop

Closes #369 (closed) and #377 (closed) and #435 (closed) .

This MR introduces a new training argument: num_updates. It is equivalent to the number of iterations in the training loop. For mini-batch SGD, this is equivalent to the number of model/optimizer updates, hence the name.

For backward compatibility we still support using epochs, however we encourage using num_updates whenever possible. For this reason I have changed some notebooks, but not all.

Note that num_updates will yield a number of iterations that is the same for every node, while epochs will yield different number of iterations per node.

From a user point of view, I highly recommend reviewing the associated documentation MR fedbiomed.gitlabpages.inria.fr!66.

Implementation details

An history of the (difficult) implementation discussion can be found here and in MR !158 (closed).

New class MiniBatchTrainingIterationsAccountant

This class has the following responsibilities:

    - manage iterators for epochs and batches
    - provide up-to-date values for reporting
    - handle different semantics in case the researcher asked for num_updates or epochs

The idea of this class is to be able to always write the training loop within a training plan as such:

iterations_accountant = MiniBatchTrainingIterationsAccountant(self)  # self is the training plan instance
for epoch in iterations_accountant.iterate_epochs():
    training_data_iterator = iter(self.training_data_loader)  # we need this to be able to access the actual data
    for batch in iterations_accountant.iterate_batches():
        data, target = next(training_data_iterator)  # retrieve the actual data
        iterations_accountant.increment_sample_counters(batch_size)  # update internal reporting attributes
        # ... do the training
        if iterations_accountant.should_log_this_batch():
            # Retrieve reporting information
            num_samples, num_samples_max = iterations_accountant.reporting_on_num_samples()
            num_iter, num_iter_max = iterations_accountant.reporting_on_num_iter()
            epoch_to_report = iterations_accountant.reporting_on_epoch()
            # do the reporting

The loop above illustrates both the use of this class to manage the iterations, as well as its use to simplify reporting.

Handling both num_updates and epochs

For handling both num_updates and epochs, the implementation relies on a simple idea: converting num_updates into epochs on each node. Therefore the training loop inside the training plan should always be implemented in terms of epochs and batches.

The logic of this conversion can be found inside MiniBatchTrainingIterationsAccountant._n_training_iterations, a function whose purpose is to provide all the necessary information to manage the number of iterations in the training loop. The signature of this function is

def _n_training_iterations(self):

The function does not return anything, but it updates three internal attributes:

  1. number of epochs
  2. number of batches in last epoch, i.e. an additional number of iterations to be performed in the last epoch if num_updates was defined
  3. number of batches per epoch

Note: we always perform one additional epoch. This additional epoch may be empty. This is done to simplify the implementation.

Here are some examples. Consider a dataloader with a length of 4. Then

training args epochs batches in last epoch batches per epoch
{'epochs': 1} 2 0 4
{'epochs': 1, 'batch_maxnum': 2} 2 0 2
{'num_updates': 7} 2 3 4
{'num_updates': 3} 1 3 4

Minor changes in TrainingArgs

We have decided that the training arguments epochs, batch_maxnum, and num_updates should have a default value of None. This is done to ensure two things:

  1. we respect the original idea of TrainingArgs to assign a default value to all possible training arguments that a researcher might specify
  2. we can keep track of whether the researcher specified a value for epochs, or for num updates

For more details, see this Discord thread.

batch maxnum

If the researcher specified epochs, they are allowed to specify the legacy training argument batch_maxnum. The meaning of this argument is "number of batches per epoch". For example, epochs=2 and batch_maxnum=3 leads to exactly 6 iterations, regardless of the dataset size. If the researcher specified num updates, then batch_maxnum is ignored.

Precedence rules

If the user specifies both epochs and num updates:

  1. num updates takes precedence
  2. epochs (and batch_maxnum) will be ignored

For example, epochs=2, batch_maxnum=3, num_updates=17 leads to exactly 17 iterations, because num_updates takes precedence.

Validation

The MiniBatchTrainingIterationsAccountant class is not currently used for the validation loop (although it could be with minor additions). This is because the validation loop does not have complicated logic to handle (e.g. num_updates does not apply). Most of the time, the validation loop is literally a single epoch, single batch that contains the whole validation dataset. However, this MR also fixes a reporting bug on validation, since I was working tangentially also on that.

Opacus

The current implementation reports the exact number of samples observed during training to both the researcher and the node. There is an inconsistency in the total reported number of samples: with Opacus, under certain conditions we cannot know in advance the exact number of samples that will be observed. For example, when batch_maxnum is used, or when num_updates is used. The solution adopted in this implementation is to:

  • always report the "predicted" total number of samples, except...
  • ...at the last iteration, where we update the total number of samples with the real total number of samples

Developer Certificate Of Origin (DCO)

By opening this merge request, you agree the Developer Certificate of Origin (DCO)

This DCO essentially means that:

  • you offer the changes under the same license agreement as the project, and
  • you have the right to do that,
  • you did not steal somebody else’s work.

License

Project code files should begin with these comment lines to help trace their origin:

# This file is originally part of Fed-BioMed
# SPDX-License-Identifier: Apache-2.0

Code files can be reused from another project with a compatible non-contaminating license. They shall retain the original license and copyright mentions. The CREDIT.md file and credit/ directory shall be completed and updated accordingly.

Guidelines for MR review

General:

Specific to some cases:

  • update all conda envs consistently (development and vpn, Linux and MacOS)
  • if modified researcher (eg new attributes in classes) check if breakpoint needs update (breakpoint/load_breakpoint in Experiment(), save_state/load_state in aggregators, strategies, secagg, etc.)
Edited by CREMONESI Francesco

Merge request reports