Enable skipping frozen weights when using `Model.get_weights` and `Model.set_weights`
This MR tackles issue #15 (closed), namely the fact that the current support for frozen neural network weights is imperfect.
Currently:
- `Model.compute_batch_gradients` properly ignores frozen (i.e. non-trainable) weights.
- `Model.get_weights` / `Model.set_weights` return/expect all weights, including frozen ones.

This causes a core issue: weight-decay and loss `Regularizer` instances cause bugs when using models with frozen weights.
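To make the mismatch concrete, here is a minimal, purely illustrative sketch (the `ToyModel` class and its attributes are invented for this example; only the `get_weights`/`trainable` names come from the MR): gradients cover trainable weights only, while the old `get_weights` behaviour returns all weights, so any term computed from the latter is misaligned with the gradients.

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for a Model with one frozen and one trainable weight."""

    def __init__(self):
        self.weights = {"w_train": np.ones(3), "w_frozen": np.ones(3)}
        self.trainable = {"w_train"}

    def compute_batch_gradients(self):
        # Gradients are produced for trainable weights only.
        return {key: np.full(3, 0.1) for key in self.trainable}

    def get_weights(self, trainable=False):
        # With trainable=False (the old behaviour), frozen weights leak in.
        if trainable:
            return {k: v for k, v in self.weights.items() if k in self.trainable}
        return dict(self.weights)

model = ToyModel()
grads = model.compute_batch_gradients()
# Old behaviour: key sets differ, so decay/regularization terms hit frozen weights.
assert set(model.get_weights()) != set(grads)
# New behaviour: keys match the gradients exactly.
assert set(model.get_weights(trainable=True)) == set(grads)
```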
This merge request:
- Adds a boolean flag to `Model.get_weights` and `Model.set_weights` that enables skipping frozen weights.
- Fixes the identified issue thanks to it: weight-decay and loss-regularization terms are computed for trainable weights only.
- Optimizes communications by excluding frozen weights when sending model updates past the initialization phase.
- Fixes potential issues with server-side optimization, which would otherwise receive zero-valued updates for non-trainable weights; in some cases these might have resulted in non-zero updates (e.g. due to the use of weight decay).
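The communication-side change can be sketched as follows. This is an illustration only, not declearn's actual code: the `ToyModel` class is invented, and only the `set_weights`/`trainable` names are taken from the MR. Past initialization, frozen weights never change, so only trainable weights need to travel over the wire.

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for a Model supporting partial weight assignment."""

    def __init__(self):
        self.weights = {"w_train": np.zeros(2), "w_frozen": np.zeros(2)}
        self.trainable = {"w_train"}

    def set_weights(self, weights, trainable=False):
        # With trainable=True, only trainable weights are expected (and set);
        # frozen weights are left untouched.
        expected = self.trainable if trainable else set(self.weights)
        if set(weights) != expected:
            raise KeyError("received weights do not match expected keys")
        self.weights.update(weights)

model = ToyModel()
# Model update received from a peer, covering trainable weights only.
update = {"w_train": np.array([1.0, 2.0])}
model.set_weights(update, trainable=True)  # frozen weights remain as-is
```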
Implementation tasklist:
- Implement `Model.get_weights(trainable: bool = False)`.
- Deploy the former to the `Optimizer` backend, fixing the current bug.
- Implement `Model.set_weights(trainable: bool = False)`.
- Deploy the former to reduce communication costs as part of the FL process.
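The `Optimizer`-side fix amounts to deriving decay terms from trainable weights only, so that their keys line up with the gradients'. A minimal sketch, assuming the weights dict is obtained via `model.get_weights(trainable=True)` (the `apply_weight_decay` helper and the `decay` parameter are invented for illustration):

```python
import numpy as np

def apply_weight_decay(gradients, weights, decay=0.01):
    # `weights` is expected to come from model.get_weights(trainable=True),
    # hence its keys match the gradients' keys exactly: no frozen weight
    # ever receives a decay term.
    return {key: grad + decay * weights[key] for key, grad in gradients.items()}

grads = {"w": np.array([0.1, 0.1])}
weights = {"w": np.array([1.0, 2.0])}  # trainable weights only
updated = apply_weight_decay(grads, weights)
# updated["w"] is [0.1 + 0.01 * 1.0, 0.1 + 0.01 * 2.0] = [0.11, 0.12]
```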
Closes #15 (closed)
Activity
added api-change bug labels
assigned to @paandrey
added 1 commit
- f1b12bd4 - Add `trainable` argument to `Model.get_weights`.
added 1 commit
- 44c4919f - Add `trainable` argument to `Model.set_weights`.
requested review from @nbigaud
@nbigaud When you have time, could you have a look at this MR?
I would like to add (or at least run locally) some functional test cases before merging, but the core of the modifications is already implemented and unit-tested.
changed milestone to %Release declearn 2.1.0
added 1 commit
- d6af0512 - Fix unit tests for 'Model.apply_updates' with frozen weights.
added 1 commit
- 33d28c72 - Add 'MLP-tune' test case for 'TorchModel' and 'TensorflowModel'.
I fixed some existing tests and added a minimal frozen-weights case for TensorFlow and Torch. I also pushed some improvements to the `Optimizer.compute_updates_from_gradients` unit tests that, among other things, ensure `model.get_weights` is called with `trainable=True`.

We could verify the interactions with the `Optimizer` as part of a more involved functional test, but the current tests are in fact sufficient: the `Model` tests already exercise the interaction between gradients and trainable-only weight vectors, and we know the same kind of data will be produced within the `Optimizer`.
added 19 commits
- 5a27cd26...ae0c0796 - 18 commits from branch `develop`
- a2f34c14 - Merge branch 'develop' into handle-frozen-weights
enabled an automatic merge when the pipeline for a2f34c14 succeeds
removed review request for @nbigaud
mentioned in issue #15 (closed)
mentioned in commit dba20501
mentioned in merge request fedbiomed/fedbiomed!188 (merged)