Mentions légales du service

Skip to content

Draft: num updates

CREMONESI Francesco requested to merge feature/369-num-updates into develop

Closes #369 (closed) and #377 (closed).

This MR introduces a new training argument: num_updates. It is equivalent to the number of iterations in the training loop. For mini-batch SGD, this is equivalent to the number of model/optimizer updates, hence the name.

For backward compatibility we still support using epochs, however we encourage using num_updates whenever possible. For this reason I have changed some notebooks, but not all.

From a user point of view, I highly recommend reviewing the associated documentation MR!66.

Implementation details

Major changes in DataLoaders

DataLoaders have significantly different semantics now. In particular, both torch and numpy dataloaders now iterate infinitely over the dataset, building batches that are always of the same size. This is different from the previous behaviour, where dataloaders raised StopIteration at the end of one epoch, and the last batch could potentially be smaller than the specified batch size.

You may find more information on the motivation and implementation details here.

Furthermore, I have factored out DataLoader-related implementation (for both torch and sklearn) in a separate module.

Minor changes in TrainingArgs

We have decided that the following training arguments should have a default value of None. This is done to ensure two things:

  1. we respect the original idea of TrainingArgs to assign a default value to all possible training arguments that a researcher might specify
  2. we can keep track of whether the researcher specified a value for epochs, or for num updates

For more details, see this Discord thread.

Minor changes in the training loop (TrainingPlan)

For both torch and scikit-learn, the training loop is no longer a nested loop (first over epochs, then over batches), but a single loop over the number of updates. This is explained more in detail here.

batch maxnum

If the researcher specified epochs, they are allowed to specify the legacy training argument batch_maxnum. The meaning of this argument is "number of batches per epoch". For example, epochs=2 and batch_maxnum=3 leads to exactly 6 iterations, regardless of the dataset size.

Precedence rules

If the user specifies both epochs and num updates:

  1. num updates takes precedence
  2. epochs will be ignored

For example, epochs=2, batch_maxnum=3, num_updates=17 leads to exactly 17 iterations, because num_updates takes precedence.

Guidelines for MR review


Specific to some cases:

  • update all conda envs consistently (development and vpn, Linux and MacOS)
  • if modified researcher (eg new attributes in classes) check if breakpoint needs update (breakpoint/load_breakpoint in Experiment(), save_state/load_state in aggregators, strategies, secagg, etc.)

Merge request reports