From 7166e0a6194f41c9ed1cfa27d78b3c47aef82021 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Tue, 10 Oct 2023 16:01:24 +0200 Subject: [PATCH] Revise 'declearn-quickrun' backend to use asyncio. --- declearn/quickrun/_run.py | 75 +++++++++++++++-------------- examples/mnist_quickrun/mnist.ipynb | 4 +- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/declearn/quickrun/_run.py b/declearn/quickrun/_run.py index c1b05fd2..b2d30b3c 100644 --- a/declearn/quickrun/_run.py +++ b/declearn/quickrun/_run.py @@ -30,6 +30,7 @@ The script then locally runs the FL experiment as layed out in the TOML file, using privided model and data, and stores its result in the same folder. """ +import asyncio import importlib import logging import os @@ -53,7 +54,6 @@ from declearn.test_utils import make_importable from declearn.utils import ( LOGGING_LEVEL_MAJOR, get_logger, - run_as_processes, set_device_policy, ) @@ -61,19 +61,26 @@ __all__ = ["quickrun"] def get_model(folder: str, model_config: ModelConfig) -> Model: - "Return a model instance from a model config instance" + """Return a model instance from a model config instance.""" path = model_config.model_file or os.path.join(folder, "model.py") path = os.path.abspath(path) if not os.path.isfile(path): raise FileNotFoundError("Model file not found: '{path}'.") + if not path.endswith(".py"): + raise TypeError(f"Model file at '{path}' is not a '.py' file.") with make_importable(os.path.dirname(path)): mod = importlib.import_module(os.path.basename(path)[:-3]) model = getattr(mod, model_config.model_name) + if not isinstance(model, Model): + raise TypeError( + "Imported object from the model file is required to be a " + "'declearn.model.api.Model', but is a '{type(model)}'." + ) return model def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str: - """Return the checkpoint folder, either default or as given in config""" + """Return the checkpoint folder, either default or as given in config.""" if expe_config.checkpoint: checkpoint = expe_config.checkpoint else: @@ -82,7 +89,7 @@ def get_checkpoint(folder: str, expe_config: ExperimentConfig) -> str: return checkpoint -def run_server( +async def run_server( folder: str, network: NetworkServerConfig, model_config: ModelConfig, @@ -90,7 +97,7 @@ def run_server( config: FLRunConfig, expe_config: ExperimentConfig, ) -> None: - """Routine to run a FL server, called by `run_declearn_experiment`.""" + """Routine to run a FL server, called by `quickrun`.""" # arguments serve modularity; pylint: disable=too-many-arguments set_device_policy(gpu=False) model = get_model(folder, model_config) @@ -100,24 +107,21 @@ def run_server( server = FederatedServer( model, network, optim, expe_config.metrics, checkpoint, logger ) - server.run(config) + await server.async_run(config) -def run_client( +async def run_client( folder: str, network: NetworkClientConfig, - model_config: ModelConfig, expe_config: ExperimentConfig, name: str, paths: Dict[str, str], ) -> None: - """Routine to run a FL client, called by `run_declearn_experiment`.""" - # arguments serve modularity; pylint: disable=too-many-arguments + """Routine to run a FL client, called by `quickrun`.""" # Overwrite client name based on folder name. network.name = name # Make the model importable and disable GPU use. set_device_policy(gpu=False) - _ = get_model(folder, model_config) # Add checkpointer. checkpoint = get_checkpoint(folder, expe_config) checkpoint = os.path.join(checkpoint, name) @@ -139,7 +143,7 @@ def run_client( client = FederatedClient( network, train, valid, checkpoint, logger=logger, verbose=False ) - client.run() + await client.async_run() def get_toml_folder(config: str) -> Tuple[str, str]: @@ -187,7 +191,7 @@ def locate_split_data(toml: str, folder: str) -> Dict: def server_to_client_network( network_cfg: NetworkServerConfig, ) -> NetworkClientConfig: - "Convert server network config to client network config." + """Convert server network config to client network config.""" return NetworkClientConfig.from_params( protocol=network_cfg.protocol, server_uri=network_cfg.build_server().uri, @@ -195,8 +199,8 @@ def server_to_client_network( ) -def quickrun(config: str) -> None: - """Run a server and its clients using multiprocessing. +async def quickrun(config: str) -> None: + """Run a server and its clients parallelly using asyncio. The script requires to be provided with the path to a TOML file with all the elements required to configurate an FL experiment, @@ -233,40 +237,39 @@ def quickrun(config: str) -> None: - You may refer to a more detailed MNIST example on our GitLab repository. See the `examples/mnist_quickrun` folder. """ - # main script; pylint: disable=too-many-locals toml, folder = get_toml_folder(config) - # locate split data or split it if needed + # Locate split data or split it if needed. client_dict = locate_split_data(toml, folder) - # Parse toml files + # Parse toml files. ntk_server_cfg = NetworkServerConfig.from_toml(toml, False, "network") ntk_client_cfg = server_to_client_network(ntk_server_cfg) optim_cgf = FLOptimConfig.from_toml(toml, False, "optim") run_cfg = FLRunConfig.from_toml(toml, False, "run") model_cfg = ModelConfig.from_toml(toml, False, "model", True) expe_cfg = ExperimentConfig.from_toml(toml, False, "experiment", True) - # Set up a (func, args) tuple specifying the server process. - p_server = ( - run_server, - (folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg), - ) - # Set up the (func, args) tuples specifying client-wise processes. - p_client = [] - for name, data_dict in client_dict.items(): - client = ( - run_client, - (folder, ntk_client_cfg, model_cfg, expe_cfg, name, data_dict), - ) - p_client.append(client) - # Run each and every process in parallel. - success, outputs = run_as_processes(p_server, *p_client) - assert success, "The FL process failed:\n" + "\n".join( - str(exc) for exc in outputs if isinstance(exc, RuntimeError) + # Set up the server and client-wise coroutines. + coro_server = run_server( + folder, ntk_server_cfg, model_cfg, optim_cgf, run_cfg, expe_cfg ) + coro_clients = [ + run_client(folder, ntk_client_cfg, expe_cfg, name, data_dict) + for name, data_dict in client_dict.items() + ] + # Run each and every coroutine in parallel. + await asyncio.gather(coro_server, *coro_clients) + + +def fire_quickrun(config) -> None: + """Fire-wrappable caller to 'quickrun'.""" + asyncio.run(quickrun(config)) + + +fire_quickrun.__doc__ = quickrun.__doc__ def main() -> None: """Fire-wrapped `quickrun`.""" - fire.Fire(quickrun) + fire.Fire(fire_quickrun) if __name__ == "__main__": diff --git a/examples/mnist_quickrun/mnist.ipynb b/examples/mnist_quickrun/mnist.ipynb index a55dc891..d4dc2e62 100644 --- a/examples/mnist_quickrun/mnist.ipynb +++ b/examples/mnist_quickrun/mnist.ipynb @@ -324,7 +324,7 @@ "* A folder with your data, split by client. Here: `examples/mnist_quickrun/data_iid`\n", "* A model python file, to declare your model wrapped in a `declearn` object. Here: `examples/mnist_quickrun/model.py`.\n", "\n", - "We then only have to run the `quickrun` util with the path to the TOML file:" + "We then only have to run the `quickrun` coroutine with the path to the TOML file:" ] }, { @@ -337,7 +337,7 @@ "source": [ "from declearn.quickrun import quickrun\n", "\n", - "quickrun(config=\"examples/mnist_quickrun/config.toml\")" + "await quickrun(config=\"examples/mnist_quickrun/config.toml\")" ] }, { -- GitLab