Feature/513 Fix warnings in `TorchModel.set_weights` and BatchNorm layers' handling

This MR addresses issue #513, which is about `TorchModel.set_weights` generating warnings when the model's state dict contains tensors or values that are not part of the model's parameters.

The suggested fix is simply to add a secondary filter on these warnings, so that input weights are expected not to cover these values (and may even exclude non-trainable weights), for consistency with the outputs of the counterpart `get_weights` method.
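
As a rough illustration of the intended filtering, here is a minimal sketch (not the actual `TorchModel.set_weights` implementation; the standalone `set_weights` signature is an assumption for illustration) of how expected-missing keys can be told apart from genuine mismatches when loading a partial state dict:

```python
# Hedged sketch, not the actual TorchModel.set_weights code: filter the
# "missing keys" reported by torch when loading a partial weights dict,
# so that only genuinely unexpected mismatches trigger a warning.
import warnings

import torch


def set_weights(model: torch.nn.Module, weights: dict) -> None:
    """Load input weights into the model, warning only on true mismatches.

    Keys absent from `weights` that do not map to trainable parameters
    (e.g. BatchNorm running statistics, or frozen parameters) are
    expected to be missing, mirroring what the counterpart `get_weights`
    would output, and are therefore filtered out before warning.
    """
    result = model.load_state_dict(weights, strict=False)
    # Names of trainable torch.nn.Parameter entries in the model.
    param_names = {
        name for name, prm in model.named_parameters() if prm.requires_grad
    }
    # Only warn about missing keys that denote trainable parameters.
    missing = [key for key in result.missing_keys if key in param_names]
    if missing or result.unexpected_keys:
        warnings.warn(
            "Unexpected mismatch when setting model weights: "
            f"missing={missing}, unexpected={list(result.unexpected_keys)}",
            RuntimeWarning,
        )
```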

Edit: the initial issue stemmed from the presence of BatchNorm layers, whose handling matters in a federated context. As such, and based on the exchanges recorded on the issue's page, this MR was also extended to:

  • share non-`torch.nn.Parameter` model states at the end of rounds so that they are aggregated
  • implement a new training argument (an optional boolean flag defaulting to `True`) that makes it possible not to share these values (although, in the current state of things, that is not recommended; see #529); a sketch of this logic is given after this list
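
To make the intent of these two points concrete, the sketch below shows how round-end state collection could include or exclude non-parameter states. The `get_weights` signature and the `share_buffers` flag name are illustrative assumptions, not the MR's actual code:

```python
# Hedged sketch of round-end state collection; `share_buffers` stands in
# for the new optional training flag and is an assumed name.
import torch


def get_weights(model: torch.nn.Module, share_buffers: bool = True) -> dict:
    """Return the model states to be shared for aggregation."""
    if share_buffers:
        # Default: share the full state dict, i.e. parameters plus
        # buffers such as BatchNorm's running_mean / running_var.
        return {key: tns.detach().clone() for key, tns in model.state_dict().items()}
    # Opt-out: share torch.nn.Parameter entries only, keeping BatchNorm
    # statistics client-local (currently discouraged; see #529).
    return {key: prm.detach().clone() for key, prm in model.named_parameters()}


# Example: with a BatchNorm layer in the model, the shared dict includes
# the layer's running statistics by default.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
states = get_weights(model)
assert "1.running_mean" in states and "1.running_var" in states
```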