Enable skipping frozen weights when using `Model.get_weights` and `Model.set_weights`
This MR tackles issue #15 (closed), namely the fact that the current support for frozen neural network weights is imperfect.
Currently:

- `Model.compute_batch_gradients` properly ignores frozen (i.e. non-trainable) weights.
- `Model.get_weights` / `Model.set_weights` returns/expects all weights, including frozen ones.
This causes a core issue: weight-decay and loss `Regularizer` instances cause bugs when used with models that have frozen weights.
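To make the failure mode concrete, here is a purely illustrative sketch, using plain NumPy dicts and made-up weight names as stand-ins for the framework's actual weight structures: gradients only cover trainable weights, while the current `get_weights` returns every weight, so a naively-built weight-decay term ends up producing updates for frozen weights.

```python
# Hypothetical reproduction of the mismatch (illustrative only, not library code).
import numpy as np

weights = {"dense/kernel": np.ones(4), "frozen/embedding": np.ones(8)}
trainable = {"dense/kernel"}

# What compute_batch_gradients effectively yields: frozen weights are skipped.
gradients = {name: np.full_like(weights[name], 0.1) for name in trainable}

# A weight-decay term naively built from get_weights(), which covers all weights.
decay = {name: 0.01 * value for name, value in weights.items()}

# Combining the two silently produces updates for frozen weights.
for name, term in decay.items():
    gradients[name] = gradients.get(name, 0.0) + term

print(sorted(gradients))  # 'frozen/embedding' now receives a non-zero update
```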
This merge request:

- Adds a boolean flag to `Model.get_weights` and `Model.set_weights` that enables skipping frozen weights.
- Fixes the identified issue thanks to it: weight-decay and loss-regularization terms are now computed over trainable weights only (see the sketch after this list).
- Optimizes communications by excluding frozen weights from model updates sent past the initialization phase.
- Fixes potential issues with server-side optimization, which would otherwise receive zero-valued updates for non-trainable weights; in some cases these might have resulted in non-zero updates (e.g. due to the use of weight decay).
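As a counterpart to the sketch above, the following illustrates the behaviour targeted by these changes: decay and regularization terms are derived from the trainable subset only, so they line up with the gradients, and only that subset needs to be communicated past initialization. Again, the dicts and weight names are stand-ins, not actual library code.

```python
# Illustrative sketch of the fixed behaviour (not actual library code).
import numpy as np

weights = {"dense/kernel": np.ones(4), "frozen/embedding": np.ones(8)}
trainable = {"dense/kernel"}

# With the new flag, regularization terms are built from trainable weights
# only, i.e. what get_weights(trainable=True) would return.
trainable_weights = {name: weights[name] for name in trainable}
gradients = {name: np.full_like(val, 0.1) for name, val in trainable_weights.items()}
decay = {name: 0.01 * val for name, val in trainable_weights.items()}

# Only these updates need to be shared over the network, and frozen weights
# are guaranteed to remain untouched on every peer.
updates = {name: gradients[name] + decay[name] for name in trainable_weights}
print(sorted(updates))  # ['dense/kernel'] only
```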
Implementation tasklist:

- Implement `Model.get_weights(trainable: bool = False)` (semantics sketched below).
- Deploy the former to the `Optimizer` backend, fixing the current bug.
- Implement `Model.set_weights(trainable: bool = False)`.
- Deploy the former to reduce communication costs as part of the FL process.
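For concreteness, here is a minimal sketch of the intended semantics of the flagged getter/setter on a toy NumPy-backed class. This is not the actual implementation; the class and weight names are made up, and it merely illustrates that `trainable=True` restricts both methods to non-frozen weights.

```python
# Toy illustration of the `trainable` flag semantics (not the actual implementation).
import numpy as np


class ToyModel:
    """Toy NumPy-backed model with one frozen and one trainable weight."""

    def __init__(self):
        self._weights = {
            "dense/kernel": np.zeros(4),     # trainable
            "frozen/embedding": np.ones(8),  # frozen (non-trainable)
        }
        self._trainable = {"dense/kernel"}

    def get_weights(self, trainable: bool = False):
        """Return weights; skip frozen ones when `trainable=True`."""
        names = self._trainable if trainable else self._weights.keys()
        return {name: self._weights[name].copy() for name in names}

    def set_weights(self, weights, trainable: bool = False):
        """Assign weights; only accept trainable ones when `trainable=True`."""
        allowed = self._trainable if trainable else set(self._weights)
        for name, value in weights.items():
            if name not in allowed:
                raise KeyError(f"Unexpected or non-assignable weight: {name}")
            self._weights[name] = np.asarray(value)


model = ToyModel()
updates = {name: val + 1.0 for name, val in model.get_weights(trainable=True).items()}
model.set_weights(updates, trainable=True)  # frozen weights remain untouched
print(model.get_weights())                  # 'frozen/embedding' is unchanged
```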
Closes #15 (closed)