From 7166e0a6194f41c9ed1cfa27d78b3c47aef82021 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Tue, 10 Oct 2023 16:01:24 +0200
Subject: [PATCH] Revise 'declearn-quickrun' backend to use asyncio.

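Replace the 'run_as_processes' multiprocessing backend with asyncio:
'run_server' and 'run_client' are now coroutines that await the
'async_run' methods of FederatedServer and FederatedClient, and
'quickrun' gathers them concurrently via 'asyncio.gather'. A small
'fire_quickrun' wrapper preserves the synchronous CLI entry point, and
'get_model' gains type checks on the imported model file and object.

A minimal usage sketch of the revised entry point (the config path is
the one from the bundled MNIST example; any quickrun TOML file works):

    import asyncio

    from declearn.quickrun import quickrun

    # From synchronous code, run the coroutine in a fresh event loop.
    asyncio.run(quickrun(config="examples/mnist_quickrun/config.toml"))

Within an already-running event loop (e.g. a Jupyter notebook), the
coroutine may be awaited directly, as in the updated notebook cell:
'await quickrun(config=...)'.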
---
 declearn/quickrun/_run.py           | 75 +++++++++++++++--------------
 examples/mnist_quickrun/mnist.ipynb |  4 +-
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py
index c1b05fd2..b2d30b3c 100644
--- a/declearn/quickrun/_run.py
+++ b/declearn/quickrun/_run.py
@@ -30,6 +30,7 @@ The script then locally runs the FL experiment as laid out in the TOML file,
 using provided model and data, and stores its result in the same folder.
 """
 
+import asyncio
 import importlib
 import logging
 import os
@@ -53,7 +54,6 @@ from declearn.test_utils import make_importable
 from declearn.utils import (
     LOGGING_LEVEL_MAJOR,
     get_logger,
-    run_as_processes,
     set_device_policy,
 )
 
@@ -61,19 +61,26 @@ __all__ = ["quickrun"]
 
 
 def get_model(folder: str, model_config: ModelConfig) -> Model:
-    "Return a model instance from a model config instance"
+    """Return a model instance from a model config instance."""
     path = model_config.model_file or os.path.join(folder, "model.py")
     path = os.path.abspath(path)
     if not os.path.isfile(path):
         raise FileNotFoundError(f"Model file not found: '{path}'.")
+    if not path.endswith(".py"):
+        raise TypeError(f"Model file at '{path}' is not a '.py' file.")
     with make_importable(os.path.dirname(path)):
         mod = importlib.import_module(os.path.basename(path)[:-3])
         model = getattr(mod, model_config.model_name)
+    if not isinstance(model, Model):
+        raise TypeError(
+            "Imported object from the model file is required to be a "
+            "'declearn.model.api.Model', but is a '{type(model)}'."
+        )
     return model
 
 
 def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str:
-    """Return the checkpoint folder, either default or as given in config"""
+    """Return the checkpoint folder, either default or as given in config."""
     if expe_config.checkpoint:
         checkpoint = expe_config.checkpoint
     else:
@@ -82,7 +89,7 @@ def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str:
     return checkpoint
 
 
-def run_server(
+async def run_server(
     folder: str,
     network: NetworkServerConfig,
     model_config: ModelConfig,
@@ -90,7 +97,7 @@ def run_server(
     config: FLRunConfig,
     expe_config: ExperimentConfig,
 ) -> None:
-    """Routine to run a FL server, called by `run_declearn_experiment`."""
+    """Routine to run a FL server, called by `quickrun`."""
     # arguments serve modularity; pylint: disable=too-many-arguments
     set_device_policy(gpu=False)
     model = get_model(folder, model_config)
@@ -100,24 +107,21 @@ def run_server(
     server = FederatedServer(
         model, network, optim, expe_config.metrics, checkpoint, logger
     )
-    server.run(config)
+    await server.async_run(config)
 
 
-def run_client(
+async def run_client(
     folder: str,
     network: NetworkClientConfig,
-    model_config: ModelConfig,
     expe_config: ExperimentConfig,
     name: str,
     paths: Dict[str, str],
 ) -> None:
-    """Routine to run a FL client, called by `run_declearn_experiment`."""
-    # arguments serve modularity; pylint: disable=too-many-arguments
+    """Routine to run a FL client, called by `quickrun`."""
     # Overwrite client name based on folder name.
     network.name = name
     # Make the model importable and disable GPU use.
     set_device_policy(gpu=False)
-    _ = get_model(folder, model_config)
     # Add checkpointer.
     checkpoint = get_checkpoint(folder, expe_config)
     checkpoint = os.path.join(checkpoint, name)
@@ -139,7 +143,7 @@ def run_client(
     client = FederatedClient(
         network, train, valid, checkpoint, logger=logger, verbose=False
     )
-    client.run()
+    await client.async_run()
 
 
 def get_toml_folder(config: str) -> Tuple[str, str]:
@@ -187,7 +191,7 @@ def locate_split_data(toml: str, folder: str) -> Dict:
 def server_to_client_network(
     network_cfg: NetworkServerConfig,
 ) -> NetworkClientConfig:
-    "Convert server network config to client network config."
+    """Convert server network config to client network config."""
     return NetworkClientConfig.from_params(
         protocol=network_cfg.protocol,
         server_uri=network_cfg.build_server().uri,
@@ -195,8 +199,8 @@ def server_to_client_network(
     )
 
 
-def quickrun(config: str) -> None:
-    """Run a server and its clients using multiprocessing.
+async def quickrun(config: str) -> None:
+    """Run a server and its clients parallelly using asyncio.
 
     The script requires to be provided with the path to a TOML file
     with all the elements required to configure an FL experiment,
@@ -233,40 +237,39 @@ def quickrun(config: str) -> None:
     - You may refer to a more detailed MNIST example on our GitLab repository.
       See the `examples/mnist_quickrun` folder.
     """
-    # main script; pylint: disable=too-many-locals
     toml, folder = get_toml_folder(config)
-    # locate split data or split it if needed
+    # Locate split data or split it if needed.
     client_dict = locate_split_data(toml, folder)
-    # Parse toml files
+    # Parse toml files.
     ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network")
     ntk_client_cfg = server_to_client_network(ntk_server_cfg)
     optim_cgf = FLOptimConfig.from_toml(toml, False, "optim")
     run_cfg = FLRunConfig.from_toml(toml, False, "run")
     model_cfg = ModelConfig.from_toml(toml, False, "model", True)
     expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment", True)
-    # Set up a (func, args) tuple specifying the server process.
-    p_server = (
-        run_server,
-        (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg),
-    )
-    # Set up the (func, args) tuples specifying client-wise processes.
-    p_client = []
-    for name, data_dict in client_dict.items():
-        client = (
-            run_client,
-            (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict),
-        )
-        p_client.append(client)
-    # Run each and every process in parallel.
-    success, outputs = run_as_processes(p_server, *p_client)
-    assert success, "The FL process failed:\n" + "\n".join(
-        str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+    # Set up the server and client-wise coroutines.
+    coro_server = run_server(
+        folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg
     )
+    coro_clients = [
+        run_client(folder, ntk_client_cfg, expe_cfg, name, data_dict)
+        for name, data_dict in client_dict.items()
+    ]
+    # Run the server and client coroutines concurrently.
+    await asyncio.gather(coro_server, *coro_clients)
+
+
+def fire_quickrun(config: str) -> None:
+    """Fire-wrappable synchronous wrapper around 'quickrun'."""
+    asyncio.run(quickrun(config))
+
+
+fire_quickrun.__doc__ = quickrun.__doc__
 
 
 def main() -> None:
     """Fire-wrapped `quickrun`."""
-    fire.Fire(quickrun)
+    fire.Fire(fire_quickrun)
 
 
 if __name__ == "__main__":
diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb
index a55dc891..d4dc2e62 100644
--- a/examples/mnist_quickrun/mnist.ipynb
+++ b/examples/mnist_quickrun/mnist.ipynb
@@ -324,7 +324,7 @@
     "* A folder with your data, split by client. Here: `examples/mnist_quickrun/data_iid`\n",
     "* A model python file, to declare your model wrapped in a `declearn` object. Here: `examples/mnist_quickrun/model.py`.\n",
     "\n",
-    "We then only have to run the `quickrun` util with the path to the TOML file:"
+    "We then only have to run the `quickrun` coroutine with the path to the TOML file:"
    ]
   },
   {
@@ -337,7 +337,7 @@
    "source": [
     "from declearn.quickrun import quickrun\n",
     "\n",
-    "quickrun(config=\"examples/mnist_quickrun/config.toml\")"
+    "await quickrun(config=\"examples/mnist_quickrun/config.toml\")"
    ]
   },
   {
-- 
GitLab