PoC SCAFFOLD aggregation strategy
Introduction
The purpose of this issue is to highlight the work in progress regarding the implementation of a new FL algorithm / aggregation strategy named SCAFFOLD.
Why use SCAFFOLD?
Despite FedAvg being an algorithm of choice for federated learning, it suffers from ‘client-drift’ when the data is heterogeneous (non-iid), resulting in unstable and slow convergence (Karimireddy et al., 2020).
SCAFFOLD uses control variates (variance reduction) to correct for the ‘client-drift’ in its local updates.
Intuitively, SCAFFOLD estimates the update direction for the server model (c) and the update direction for each client (c_i). The difference (c − c_i) is then an estimate of the client-drift which is used to correct the local update.
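The drift-corrected local update can be sketched as follows. This is an illustrative plain-Python sketch of the SCAFFOLD step on a single scalar parameter, not the Fed-BioMed implementation; `scaffold_local_step` is a hypothetical helper name.

```python
def scaffold_local_step(y_i, grad, c, c_i, lr):
    """One SCAFFOLD-corrected local SGD step (illustrative sketch).

    The raw client gradient is corrected by the estimated drift (c - c_i)
    before the usual SGD step: y_i <- y_i - lr * (grad + c - c_i).
    """
    return y_i - lr * (grad + c - c_i)


# With no correction (c == c_i) this reduces to plain SGD;
# otherwise the step is nudged toward the server's update direction.
plain = scaffold_local_step(1.0, grad=2.0, c=0.0, c_i=0.0, lr=0.1)      # 0.8
corrected = scaffold_local_step(1.0, grad=2.0, c=1.0, c_i=0.5, lr=0.1)  # 0.75
```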
Accomplished work
Branch link: https://gitlab.inria.fr/fedbiomed/fedbiomed/-/tree/poc/scaffold
- In fedbiomed/researcher folder:
  - `experiment.py`:
    - `_set_new_correction_states_dict`: defines each client's correction state for round i+1 at round i.
    - `_set_new_client_states_dict`: defines each client's state for round i+1 at round i, scaling the local parameters by `server_lr`.
    - The `run_once` method has been modified to transmit the correction states to the job, which takes care of starting the nodes' training round. After optimization, the new correction states and new client states are also computed in `run_once`.
  - `job.py`: the `start_nodes_training_round` method takes a new parameter, `strategy_info`, a dictionary describing the adopted strategy. For SCAFFOLD it includes the correction states, which are now part of the message sent by the request.
  - New file `scaffold.py`: takes `server_lr` (the server learning rate) as an initialization parameter.
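The two `_set_new_...` steps in `experiment.py` can be sketched as follows. This is a plain-Python sketch over per-parameter dicts of floats (standing in for tensors); the control-variate rule is the SCAFFOLD paper's "option II", and the function names here are illustrative, not copied from the branch.

```python
def new_correction_state(c_i, c, x, y_i, local_lr, n_steps):
    """Client control variate for round i+1 (SCAFFOLD 'option II'):
    c_i+ = c_i - c + (x - y_i) / (n_steps * local_lr),
    where x is the server model and y_i the client model after local training."""
    return {k: c_i[k] - c[k] + (x[k] - y_i[k]) / (n_steps * local_lr) for k in x}


def new_client_state(x, y_i, server_lr):
    """Client state for round i+1, with the local update scaled by server_lr:
    x+ = x + server_lr * (y_i - x)."""
    return {k: x[k] + server_lr * (y_i[k] - x[k]) for k in x}


c_plus = new_correction_state(
    c_i={"w": 0.0}, c={"w": 0.2}, x={"w": 1.0}, y_i={"w": 0.8},
    local_lr=0.1, n_steps=10,
)  # {"w": 0.0 - 0.2 + (1.0 - 0.8) / (10 * 0.1)} == {"w": 0.0}
x_plus = new_client_state(x={"w": 1.0}, y_i={"w": 0.8}, server_lr=0.5)
# {"w": 1.0 + 0.5 * (0.8 - 1.0)} == {"w": 0.9}
```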
- In fedbiomed/common folder:
  - `_torchnn.py`: the `training_routine` method takes a new parameter, `correction_state`, a dictionary used with the SCAFFOLD strategy. For each batch update, the dot product between the current model state and the client correction state is computed and subtracted from the loss, giving a `corrected_loss`. Gradients are then computed on this `corrected_loss`.
  - `message.py`: `correction_state` is now part of the train message sent by the researcher.
  - `utils.py`: new `compute_dot_product` function, returning the dot product between a model state and input parameters (e.g. a correction state).
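The loss correction described above can be sketched as follows. This is a plain-Python stand-in for the tensor arithmetic: flat lists of floats replace the PyTorch state-dict tensors that the real `compute_dot_product` in `utils.py` operates on.

```python
def compute_dot_product(model_state, correction_state):
    """Sum, over all layers, of the elementwise products between the model
    parameters and the correction parameters (lists stand in for tensors)."""
    return sum(
        sum(w * c for w, c in zip(model_state[name], correction_state[name]))
        for name in model_state
    )


model_state = {"fc.weight": [1.0, 2.0], "fc.bias": [0.5]}
correction = {"fc.weight": [0.1, 0.1], "fc.bias": [0.2]}

loss = 1.0
# The dot product is subtracted from the batch loss before backprop:
corrected_loss = loss - compute_dot_product(model_state, correction)
# 1.0 - (1.0*0.1 + 2.0*0.1 + 0.5*0.2) == 0.6
```

In the real `training_routine`, gradients would then be taken on `corrected_loss` instead of `loss`.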
- In fedbiomed/node folder:
  - `node.py`: the `parser_task` method now parses `correction_state` from the task message and passes it to the round instance, so the correction is applied when the SCAFFOLD strategy is used.
  - `round.py`: `correction_state` is now part of the initialization parameters. It is passed to the training kwargs so it can be applied in the training routine.
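The node-side plumbing can be sketched as follows. The message shape, the `parse_correction_state` helper and the minimal `Round` class below are hypothetical stand-ins for `Node.parser_task` and `fedbiomed.node.round.Round`, only meant to show how `correction_state` flows from the task message into the training kwargs.

```python
def parse_correction_state(task_msg):
    """Extract the correction state from a task message
    (hypothetical message shape)."""
    return task_msg.get("correction_state", {})


class Round:
    """Minimal stand-in: the correction state received at init
    is forwarded to the training routine via kwargs."""

    def __init__(self, correction_state=None):
        self.correction_state = correction_state or {}

    def training_kwargs(self):
        return {"correction_state": self.correction_state}


msg = {"command": "train", "correction_state": {"fc.weight": [0.1]}}
kwargs = Round(parse_correction_state(msg)).training_kwargs()
# kwargs["correction_state"] == {"fc.weight": [0.1]}
```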