Enable skipping frozen weights when using `Model.get_weights` and `Model.set_weights`

ANDREY Paul requested to merge handle-frozen-weights into develop

This MR tackles issue #15 (closed): current support for frozen neural network weights is imperfect.

Currently:

  • Model.compute_batch_gradients properly ignores frozen (i.e. non-trainable) weights
  • Model.get_weights / Model.set_weights return/expect all weights, including frozen ones

This asymmetry causes a core issue: weight-decay and loss Regularizer terms are computed over all weights, frozen ones included, which breaks models with frozen weights (as sketched below).
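
A minimal sketch of the failure mode, using hypothetical names rather than this project's actual API: when the decay term is derived from a full-weight getter, frozen weights receive non-zero updates.

```python
import numpy as np

# Hypothetical weight store; "frozen/kernel" stands for a non-trainable weight.
weights = {"dense/kernel": np.ones((2, 2)), "frozen/kernel": np.ones((2, 2))}
trainable = {"dense/kernel"}

decay = 0.01
# Buggy behaviour: the decay term is computed over *all* weights,
# hence the frozen weight would be altered.
updates = {name: decay * value for name, value in weights.items()}
assert updates["frozen/kernel"].any()

# Fixed behaviour: restrict the decay term to trainable weights only.
updates = {
    name: decay * value
    for name, value in weights.items()
    if name in trainable
}
assert "frozen/kernel" not in updates
```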

This merge request:

  • Adds a boolean flag to Model.get_weights and Model.set_weights that enables skipping frozen weights.
  • Uses that flag to fix the identified issue: weight-decay and loss-regularization terms are now computed for trainable weights only.
  • Optimizes communications by excluding frozen weights when sending model updates past the initialization phase.
  • Fixes potential issues with server-side optimization, which would otherwise receive zero-valued updates for non-trainable weights and might in some cases have turned them into non-zero ones (e.g. due to the use of weight decay).
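
A minimal sketch of what the flagged getter/setter could look like, assuming a dict-of-arrays weight store (illustrative only; the actual implementation may differ):

```python
from typing import Dict, Set

import numpy as np


class Model:
    """Toy stand-in for a Model interface with frozen-weight support."""

    def __init__(
        self, weights: Dict[str, np.ndarray], trainable: Set[str]
    ) -> None:
        self._weights = weights
        self._trainable = trainable

    def get_weights(self, trainable: bool = False) -> Dict[str, np.ndarray]:
        """Return the model's weights, optionally skipping frozen ones."""
        if trainable:
            return {
                key: val
                for key, val in self._weights.items()
                if key in self._trainable
            }
        return dict(self._weights)

    def set_weights(
        self, weights: Dict[str, np.ndarray], trainable: bool = False
    ) -> None:
        """Assign weights; with trainable=True, expect trainable ones only."""
        expected = self._trainable if trainable else set(self._weights)
        if set(weights) != expected:
            raise KeyError("input weights do not match the expected set")
        self._weights.update(weights)
```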

Implementation tasklist:

  • Implement Model.get_weights(trainable: bool = False).
  • Deploy it to the Optimizer backend, fixing the current bug.
  • Implement Model.set_weights(trainable: bool = False).
  • Deploy it to reduce communication costs as part of the FL process (see the usage sketch below).
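
Reusing the hypothetical Model sketch above, the communication saving would look as follows: the full weights travel once at initialization, and only trainable ones in later rounds.

```python
model = Model(
    weights={"dense/kernel": np.zeros(4), "frozen/kernel": np.zeros(4)},
    trainable={"dense/kernel"},
)

# Initialization phase: the full set of weights is exchanged once.
init_payload = model.get_weights()                 # 2 arrays
# Later training rounds: frozen weights are skipped on both ends.
round_payload = model.get_weights(trainable=True)  # 1 array
model.set_weights(round_payload, trainable=True)
```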

Closes #15 (closed)
