From 23b96bd13121d105bf4abda5ca709628bfc318ad Mon Sep 17 00:00:00 2001
From: Brahim Erraji <brahim.erraji@inria.fr>
Date: Wed, 6 Nov 2024 11:00:17 +0100
Subject: [PATCH] Add an option to calculate the optimal loss for different
 model architectures
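
A minimal usage sketch of the new model_type option (the netwk and
train_data values below are placeholders for a client network
configuration and a training Dataset built elsewhere; other constructor
arguments are omitted):

    from declearn.fairness.FOG._client import Client_FOG

    # "netwk" and "train_data" stand for a client network configuration
    # and a declearn Dataset prepared elsewhere in the setup script.
    client = Client_FOG(
        netwk=netwk,
        train_data=train_data,
        num_local_opt_rounds=10,
        lmbda=1.0,
        model_type="linear",  # or "cnn" (the default)
    )
    # The chosen value is forwarded to TrainingManager_FOG and keyed into
    # the cached optimal-loss path:
    # OptLosses/<client_id>_<model_type>_<num_local_rounds>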

---
 declearn/fairness/FOG/_client.py |  8 +++++---
 declearn/optimizer/_base.py      | 15 +++------------
 declearn/training/_manager.py    |  6 ++++--
 3 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/declearn/fairness/FOG/_client.py b/declearn/fairness/FOG/_client.py
index 414790e..b8aad05 100644
--- a/declearn/fairness/FOG/_client.py
+++ b/declearn/fairness/FOG/_client.py
@@ -23,7 +23,7 @@ import dataclasses
 import logging
 import os
 import warnings
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Literal, Optional, Tuple, Union
 
 import numpy as np
 
@@ -43,7 +43,6 @@ from declearn.secagg.api import Encrypter, SecaggConfigClient, SecaggSetupQuery
 from declearn.secagg import messaging as secagg_messaging
 from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger
 from declearn.main import FederatedClient
-import sys
 
 class Client_FOG(FederatedClient):
     def __init__(
@@ -56,10 +55,12 @@ class Client_FOG(FederatedClient):
         logger: Union[logging.Logger, str, None] = None,
         verbose: bool = True,
         num_local_opt_rounds: int = 10,
-        lmbda: float = 1.0
+        lmbda: float = 1.0,
+        model_type: Literal["linear", "cnn"] = "cnn"
     ) -> None:
         self.num_local_opt_rounds = num_local_opt_rounds
         self.lmbda = lmbda
+        self.model_type = model_type
         super().__init__(
             netwk=netwk,
             train_data=train_data,
@@ -121,6 +122,7 @@ class Client_FOG(FederatedClient):
                 aggrg=message.aggrg,
                 num_local_rounds = self.num_local_opt_rounds,
                 lmbda = self.lmbda,
+                model_type = self.model_type,
                 train_data=self.train_data,
                 valid_data=self.valid_data,
                 metrics=message.metrics,
diff --git a/declearn/optimizer/_base.py b/declearn/optimizer/_base.py
index a8477da..ded1bec 100644
--- a/declearn/optimizer/_base.py
+++ b/declearn/optimizer/_base.py
@@ -353,7 +353,6 @@ class Optimizer:
         self,
         model: Model[Vector[T]],
         gradients: Vector[T],
-        w_i : Optional[float] = None,
     ) -> Vector[T]:
         """Compute and return model updates based on pre-computed gradients.
 
@@ -367,7 +366,6 @@ class Optimizer:
             Pre-computed vector of (pseudo-)gradients based on which to
             perform the gradient-descent step, by applying the algorithm
             defined by this optimizer and its plug-in modules.
-        w_i: is a optional weight to multiply the gradient by
 
         Returns
         -------
@@ -389,12 +387,7 @@ class Optimizer:
         
         # Apply the base learning rate.
         updates = self.lrate * gradients
-        if w_i is not None:
-            print("&&&&&&&&&&&&&&&")
-            print(w_i)
-            updates = w_i * updates
-            print("The type of the updates is :")
-            print(type(updates))
+
         # Optionally add the decoupled weight decay term.
         if self.w_decay:
             updates += self.w_decay * weights
@@ -481,7 +474,6 @@ class Optimizer:
         model: Model,
         batch: Batch,
         sclip: Optional[float] = None,
-        w_i : Optional[float] = None,
     ) -> None:
         """Perform a gradient-descent step on a given batch.
 
@@ -502,13 +494,12 @@ class Optimizer:
             This method does not return, as `model` is updated in-place.
         """
         gradients = model.compute_batch_gradients(batch, max_norm=sclip)
-        self.apply_gradients(model, gradients, w_i=w_i)
+        self.apply_gradients(model, gradients)
 
     def apply_gradients(
         self,
         model: Model[Vector[T]],
         gradients: Vector[T],
-        w_i : Optional[float] = None
     ) -> None:
         """Compute and apply model updates based on pre-computed gradients.
 
@@ -526,7 +517,7 @@ class Optimizer:
         None
             This method does not return, as `model` is updated in-place.
         """
-        updates = self.compute_updates_from_gradients(model, gradients, w_i=w_i)
+        updates = self.compute_updates_from_gradients(model, gradients)
         model.apply_updates(updates)
 
     def get_state(
diff --git a/declearn/training/_manager.py b/declearn/training/_manager.py
index b7cbd34..292e32d 100644
--- a/declearn/training/_manager.py
+++ b/declearn/training/_manager.py
@@ -18,7 +18,7 @@
 """Wrapper to run local training and evaluation rounds in a FL process."""
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -445,6 +445,7 @@ class TrainingManager_FOG(TrainingManager):
         train_data: Dataset,
         num_local_rounds : int  = 20,
         lmbda : float = 1.0,
+        model_type: Literal["linear", "cnn"] = "cnn",
         valid_data: Optional[Dataset] = None,
         metrics: Union[MetricSet, List[MetricInputType], None] = None,
         logger: Union[logging.Logger, str, None] = None,
@@ -461,6 +462,7 @@ class TrainingManager_FOG(TrainingManager):
             verbose=verbose
             )
         self.client_id = name
+        self.model_type = model_type
 
         self.lmbda = lmbda
         self.num_local_rounds = num_local_rounds
@@ -471,7 +473,7 @@ class TrainingManager_FOG(TrainingManager):
         print("****** Starting local training ********")
         print(self.client_id)
         opt_metrics_dir = 'OptLosses/'
-        client_path = os.path.join(opt_metrics_dir, self.client_id + f'_{self.num_local_rounds}')
+        client_path = os.path.join(opt_metrics_dir, self.client_id + f'_{self.model_type}_{self.num_local_rounds}')
         if  os.path.exists(client_path):
             df = pd.read_csv(client_path + '/metrics.csv')
             df_dict = df.to_dict()
-- 
GitLab