Mentions légales du service

Skip to content
Snippets Groups Projects
Commit fae6f6e6 authored by ERRAJI Brahim's avatar ERRAJI Brahim
Browse files

Fixing a math bug

parent 0823b2aa
No related branches found
No related tags found
No related merge requests found
...@@ -508,7 +508,7 @@ class TrainingManager_FOG(TrainingManager): ...@@ -508,7 +508,7 @@ class TrainingManager_FOG(TrainingManager):
): ):
        self.n_clients = len(obj_gaps_clients) #This is m - 1 to divide lambda as in algo FOG         self.n_clients = len(obj_gaps_clients) #This is m - 1 to divide lambda as in algo FOG
self.r_j = obj_gaps_clients self.r_j = obj_gaps_clients
def get_obj_gaps_other_client( def get_obj_gaps_other_clients(
self, self,
): ):
return self.r_j return self.r_j
...@@ -533,8 +533,8 @@ class TrainingManager_FOG(TrainingManager): ...@@ -533,8 +533,8 @@ class TrainingManager_FOG(TrainingManager):
w_i = 1.0 w_i = 1.0
else: else:
loss_i = nsmp = 0.0 loss_i = nsmp = 0.0
BATCH_SIZE = self.train_data.get_data_specs().n_samples bs = self.train_data.get_data_specs().n_samples
for batch in self.train_data.generate_batches(batch_size=BATCH_SIZE): for batch in self.train_data.generate_batches(batch_size=bs):
y_true, y_pred, s_wght = self.model.compute_batch_predictions(batch) y_true, y_pred, s_wght = self.model.compute_batch_predictions(batch)
s_loss = self.model.loss_function(y_true, y_pred) s_loss = self.model.loss_function(y_true, y_pred)
if s_wght is None: if s_wght is None:
...@@ -544,10 +544,11 @@ class TrainingManager_FOG(TrainingManager): ...@@ -544,10 +544,11 @@ class TrainingManager_FOG(TrainingManager):
loss_i += (s_wght * s_loss).sum() loss_i += (s_wght * s_loss).sum()
nsmp += s_wght.sum() nsmp += s_wght.sum()
loss_i /= nsmp loss_i /= nsmp
r_i = loss_i - self.L_i_theta_s
w_i = 0.0 w_i = 0.0
for loss_j in self.get_obj_gaps_other_client(): for r_j in self.get_obj_gaps_other_clients():
w_i += (loss_i - loss_j) w_i += (r_i - r_j)
w_i = 1 + (4 * (self.lmbda/ self.n_clients) * w_i) w_i = 1 + (4 * self.lmbda * w_i)*(1/self.n_clients)
gradients = self.model.compute_batch_gradients(batch) gradients = self.model.compute_batch_gradients(batch)
updates = self.optim.compute_updates_from_gradients(self.model, gradients) updates = self.optim.compute_updates_from_gradients(self.model, gradients)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment