From c379c5fc187d2eaa5f3347e71a6a13dab01b77bb Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 21 Sep 2023 12:01:11 +0200
Subject: [PATCH] Add 'verbose' argument to 'FederatedClient'.

---
 declearn/main/_client.py     | 16 +++++++++++++---
 declearn/quickrun/_run.py    |  4 +++-
 examples/heart-uci/client.py | 13 ++++++++++---
 examples/heart-uci/run.py    |  3 ++-
 examples/mnist/run_client.py |  1 +
 examples/mnist/run_demo.py   | 10 ++++------
 6 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/declearn/main/_client.py b/declearn/main/_client.py
index b712d721..89d0ee04 100644
--- a/declearn/main/_client.py
+++ b/declearn/main/_client.py
@@ -26,7 +26,7 @@ from declearn.communication import NetworkClientConfig, messaging
 from declearn.communication.api import NetworkClient
 from declearn.dataset import Dataset, load_dataset_from_json
 from declearn.main.utils import Checkpointer, TrainingManager
-from declearn.utils import get_logger
+from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger
 
 
 __all__ = [
@@ -47,6 +47,7 @@ class FederatedClient:
         checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
         share_metrics: bool = True,
         logger: Union[logging.Logger, str, None] = None,
+        verbose: bool = True,
     ) -> None:
         """Instantiate a client to participate in a federated learning task.
 
@@ -78,6 +79,11 @@ class FederatedClient:
             Logger to use, or name of a logger to set up with
             `declearn.utils.get_logger`.
             If None, use `type(self):netwk.name`.
+        verbose: bool, default=True
+            Whether to verbose about ongoing operations.
+            If True, display progress bars during training and validation
+            rounds. If False and `logger is None`, set the logger's level
+            to filter off most routine information.
         """
         # arguments serve modularity; pylint: disable=too-many-arguments
         # Assign the wrapped NetworkClient.
@@ -98,7 +104,8 @@ class FederatedClient:
         # Assign the logger and optionally replace that of the network client.
         if not isinstance(logger, logging.Logger):
             logger = get_logger(
-                logger or f"{type(self).__name__}-{netwk.name}"
+                name=logger or f"{type(self).__name__}-{netwk.name}",
+                level=logging.INFO if verbose else LOGGING_LEVEL_MAJOR,
             )
         self.logger = logger
         if replace_netwk_logger:
@@ -119,8 +126,9 @@ class FederatedClient:
         if checkpoint is not None:
             checkpoint = Checkpointer.from_specs(checkpoint)
         self.ckptr = checkpoint
-        # Record the metric-sharing boolean switch.
+        # Record the metric-sharing and verbosity bool values.
         self.share_metrics = bool(share_metrics)
+        self.verbose = bool(verbose)
         # Create a TrainingManager slot, populated at initialization phase.
         self.trainmanager = None  # type: Optional[TrainingManager]
 
@@ -262,6 +270,7 @@ class FederatedClient:
             valid_data=self.valid_data,
             metrics=message.metrics,
             logger=self.logger,
+            verbose=self.verbose,
         )
         # If instructed to do so, await a PrivacyRequest to set up DP-SGD.
         if message.dpsgd:
@@ -335,6 +344,7 @@ class FederatedClient:
             valid_data=self.trainmanager.valid_data,
             metrics=self.trainmanager.metrics,
             logger=self.trainmanager.logger,
+            verbose=self.trainmanager.verbose,
         )
         self.trainmanager.make_private(message)
 
diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py
index 46f14384..c1b05fd2 100644
--- a/declearn/quickrun/_run.py
+++ b/declearn/quickrun/_run.py
@@ -136,7 +136,9 @@ def run_client(
         paths.get("valid_data"),
         target=paths.get("valid_target"),
     )
-    client = FederatedClient(network, train, valid, checkpoint, logger=logger)
+    client = FederatedClient(
+        network, train, valid, checkpoint, logger=logger, verbose=False
+    )
     client.run()
 
 
diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py
index 3713b310..7a91772e 100644
--- a/examples/heart-uci/client.py
+++ b/examples/heart-uci/client.py
@@ -26,7 +26,7 @@ from declearn.communication import NetworkClientConfig
 from declearn.dataset import InMemoryDataset
 from declearn.dataset.examples import load_heart_uci
 from declearn.main import FederatedClient
-from declearn.test_utils import make_importable, setup_client_argparse
+from declearn.test_utils import setup_client_argparse
 
 
 FILEDIR = os.path.dirname(__file__)
@@ -37,6 +37,7 @@ def run_client(
     ca_cert: str,
     protocol: str = "websockets",
     serv_uri: str = "wss://localhost:8765",
+    verbose: bool = True,
 ) -> None:
     """Instantiate and run a given client.
 
@@ -51,6 +52,9 @@ def run_client(
         Name of the communication protocol to use.
     serv_uri: str, default="wss://localhost:8765"
         URI of the server to which to connect.
+    verbose: bool, default=True
+        Whether to be verbose in the displayed contents, including
+        all logger information and progress bars.
     """
 
     # (1-2) Interface training and optional validation data.
@@ -87,8 +91,11 @@ def run_client(
     # (5) Instantiate a FederatedClient and run it.
 
     client = FederatedClient(
-        # fmt: off
-        network, train, valid, checkpoint=f"{FILEDIR}/results/{name}"
+        netwk=network,
+        train_data=train,
+        valid_data=valid,
+        checkpoint=f"{FILEDIR}/results/{name}",
+        verbose=verbose,
         # Note: you may add `share_metrics=False` to prevent sending
         # evaluation metrics to the server, out of privacy concerns
     )
diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py
index 1a35c362..57937dd4 100644
--- a/examples/heart-uci/run.py
+++ b/examples/heart-uci/run.py
@@ -45,7 +45,8 @@ def run_demo(
         # Specify the server and client routines that need executing.
         server = (run_server, (nb_clients, sv_cert, sv_pkey))
         clients = [
-            (run_client, (name, ca_cert)) for name in NAMES[:nb_clients]
+            (run_client, {"name": name, "ca_cert": ca_cert, "verbose": False})
+            for name in NAMES[:nb_clients]
         ]
         # Run routines in isolated processes. Raise if any failed.
         success, outp = run_as_processes(server, *clients)
diff --git a/examples/mnist/run_client.py b/examples/mnist/run_client.py
index b8d84b94..3d1f2319 100644
--- a/examples/mnist/run_client.py
+++ b/examples/mnist/run_client.py
@@ -115,6 +115,7 @@ def run_client(
         valid_data=valid,
         checkpoint=checkpoint,
         logger=logger,
+        verbose=verbose,
     )
     client.run()
 
diff --git a/examples/mnist/run_demo.py b/examples/mnist/run_demo.py
index ebe5c0c3..fc8c526d 100644
--- a/examples/mnist/run_demo.py
+++ b/examples/mnist/run_demo.py
@@ -44,12 +44,10 @@ def run_demo(
 
     Parameters
     ------
-
     n_clients: int
         number of clients to run.
     data_folder: str
         Relative path to the folder holding client's data
-
     """
     # Generate the MNIST split data for this demo.
     data_folder = prepare_mnist(nb_clients, scheme, seed=seed)
@@ -59,11 +57,11 @@ def run_demo(
         ca_cert, sv_cert, sv_pkey = generate_ssl_certificates(folder)
         # Specify the server and client routines that need executing.
         server = (run_server, (nb_clients, sv_cert, sv_pkey))
-        client_args = tuple(
-            [data_folder, ca_cert, "websockets", "wss://localhost:8765", False]
-        )
+        client_kwargs = {
+            "data_folder": data_folder, "ca_cert": ca_cert, "verbose": False
+        }
         clients = [
-            (run_client, (f"client_{idx}", *client_args))
+            (run_client, (f"client_{idx}",), client_kwargs)
             for idx in range(nb_clients)
         ]
         # Run routines in isolated processes. Raise if any failed.
-- 
GitLab