Mentions légales du service

Skip to content

Enable recording training loss values

ANDREY Paul requested to merge record-training-losses into develop

This MR paves the way towards addressing issue #31, by enabling the export of training loss values, which may then be post-processed into plots and/or analysis of how things went.

It implements the following changes:

  • Add Model.collect_training_losses method to the API, implemented at the ABC level.
  • Update the backend of Model subclasses' compute_batch_gradients method to store training loss values in an ABC-defined private list attribute, accessed via collect_training_losses.
  • Add calls to collect training losses at the end of FederatedClient.training_round, and save the values into a JSON file if a Checkpointer is attached to the FederatedClient.

Notes:

  • The Model API change could have felt more natural (and be more explicit) by having compute_batch_gradients return the loss value rather than append it to a private list. However, this would be (severely) API-breaking; hence the choice to use an entirely-new method to access stored value, which also has the positive consequence of enabling to de-couple loss collection from the training loop.
  • At the moment, training losses are collected and saved at the end of the round. This may be changed in favor of saving values as training goes by, enabling to monitor the loss evolution in real time.

Merge request reports