diff --git a/README.md b/README.md index b819cfbd0766673bfefc53e913fee759e58b6c0d..fa9594dcfd7cab66aad5ebc0a83f191c0e6de6a4 100644 --- a/README.md +++ b/README.md @@ -170,11 +170,12 @@ netwk = declearn.communication.NetworkServerConfig( ) strat = declearn.strategy.FedAvg() server = declearn.main.FederatedServer(model, netwk, strat, folder="outputs") -server.run( +config = declearn.main.config.FLRunConfig.from_params( rounds=10, - regst_cfg={"min_clients": 1, "max_clients": 3, "timeout": 180}, - train_cfg={"n_epochs": 5, "batch_size": 128, "drop_remainder": False} + register={"min_clients": 1, "max_clients": 3, "timeout": 180}, + training={"n_epochs": 5, "batch_size": 128, "drop_remainder": False}, ) +server.run(config) ``` ### Client-side script @@ -451,13 +452,21 @@ details on this example and on how to run it, please refer to its own - Provide the Model, Strategy and Server objects or configurations. - Optionally provide the path to a folder where to write output files (model checkpoints and global loss history). - - Call the server's `run` method, further specifying: + - Instantiate a `declearn.main.config.FLRunConfig` to specify the process: + - Maximum number of training and evaluation rounds to run. - Registration parameters: exact or min/max number of clients to have and optional timeout delay spent waiting for said clients to join. - Training parameters: data-batching parameters and effort constraints (number of local epochs and/or steps to take, and optional timeout). - Evaluation parameters: data-batching parameters and effort constraints (optional maximum number of steps (<=1 epoch) and optional timeout). + - Early-stopping parameters (optionally): patience, tolerance, etc. as + to the global model loss's evolution throughout rounds. + - Alternatively, write up a TOML configuration file that specifies all of + the former hyper-parameters. + - Call the server's `run` method, passing it the former config object (or + the path to the TOML configuration file). + #### Clients setup instructions @@ -535,11 +544,11 @@ The **coding rules** are fairly simple: and [pylint](https://pylint.pycqa.org/en/latest/) (for more general linting); do use "type: ..." and "pylint: disable=..." comments where you think it relevant, preferably with some side explanations - (see dedicated sub-sections below: [pylint](#running-black-to-format-the-code) + (see dedicated sub-sections below: [pylint](#running-pylint-to-check-the-code) and [mypy](#running-mypy-to-type-check-the-code)) - reformat your code using [black](https://github.com/psf/black); do use (sparingly) "fmt: off/on" comments when you think it relevant - (see dedicated sub-section [below](#running-pylint-to-check-the-code)) + (see dedicated sub-section [below](#running-black-to-format-the-code)) ### Unit tests and code analysis diff --git a/declearn/main/__init__.py b/declearn/main/__init__.py index b9fd7a0cb6f8418f505cbb8d23411c5fb4d02046..1007288a31526ec64c48974672d40afbeed44617 100644 --- a/declearn/main/__init__.py +++ b/declearn/main/__init__.py @@ -3,5 +3,6 @@ """Main classes implementing a Federated Learning process.""" from . import utils +from . 
import config from ._client import FederatedClient from ._server import FederatedServer diff --git a/declearn/main/_server.py b/declearn/main/_server.py index e1b9eb6a4571d943b85db36a55eb34570cac650b..ae7192bc8aafe785bfa7647a9a52aa39558719b5 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -4,18 +4,20 @@ import asyncio import logging -from typing import Any, Dict, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Optional, Set, Type, Union from declearn.communication import NetworkServerConfig, messaging from declearn.communication.api import Server +from declearn.main.config import ( + EvaluateConfig, + FLRunConfig, + TrainingConfig, +) from declearn.main.utils import ( AggregationError, Checkpointer, EarlyStopping, - EvaluateConfig, - RegisterConfig, - TrainingConfig, aggregate_clients_data_info, ) from declearn.model.api import Model @@ -98,83 +100,34 @@ class FederatedServer: # Assign a model checkpointer. self.checkpointer = Checkpointer(self.model, folder) - def _parse_config_dicts( - self, - regst_cfg: Union[RegisterConfig, Dict[str, Any], int], - train_cfg: Union[TrainingConfig, Dict[str, Any]], - valid_cfg: Union[EvaluateConfig, Dict[str, Any], None] = None, - ) -> Tuple[RegisterConfig, TrainingConfig, EvaluateConfig]: - """Parse input keyword arguments config dicts or dataclasses.""" - if isinstance(regst_cfg, int): - regst_cfg = RegisterConfig(min_clients=regst_cfg) - elif isinstance(regst_cfg, dict): - regst_cfg = RegisterConfig(**regst_cfg) - elif not isinstance(regst_cfg, RegisterConfig): - raise TypeError( - "'regst_cfg' should be a RegisterConfig instance or dict." - ) - if isinstance(train_cfg, dict): - train_cfg = TrainingConfig(**train_cfg) - elif not isinstance(train_cfg, TrainingConfig): - raise TypeError( - "'train_cfg' should be a TrainingConfig instance or dict." - ) - if valid_cfg is None: - valid_cfg = EvaluateConfig(batch_size=train_cfg.batch_size) - elif isinstance(valid_cfg, dict): - valid_cfg = EvaluateConfig(**valid_cfg) - elif not isinstance(valid_cfg, EvaluateConfig): - raise TypeError( - "'valid_cfg' should be a EvaluateConfig instance or dict." - ) - return regst_cfg, train_cfg, valid_cfg - def run( self, - rounds: int, - regst_cfg: Union[RegisterConfig, Dict[str, Any], int], - train_cfg: Union[TrainingConfig, Dict[str, Any]], - valid_cfg: Union[EvaluateConfig, Dict[str, Any], None] = None, - early_stop: Optional[Union[EarlyStopping, Dict[str, Any]]] = None, + config: Union[FLRunConfig, str, Dict[str, Any]], ) -> None: """Orchestrate the federated learning routine. Parameters ---------- - rounds: int - Maximum number of training rounds to perform. - regst_cfg: RegisterConfig or dict or int - Keyword arguments to specify clients-registration rules - - formatted as a dict or a declearn.main.utils.RegisterConfig - instance. Alternatively, use an int to specify the exact - number of clients that are expected to register. - train_cfg: TrainingConfig or dict - Keyword arguments to specify effort constraints and data - batching parameters for training rounds - formatted as a - dict or a declearn.main.utils.TrainingConfig instance. - valid_cfg: EvaluateConfig or dict or None, default=None - Keyword arguments to specify effort constraints and data - batching parameters for evaluation rounds. If None, use - default arguments (1 epoch over batches of same size as - for training, without shuffling nor samples dropping). 
-        early_stop: EarlyStopping or dict or None, default=None
-            Optional EarlyStopping instance or configuration dict,
-            specifying an early-stopping rule based on the global
-            loss metric computed during evaluation rounds.
+        config: FLRunConfig or str or dict
+            Container instance wrapping grouped hyper-parameters that
+            specify the federated learning process, including clients
+            registration, training and validation rounds' setup, plus
+            an optional early-stopping criterion.
+            May be a str pointing to a TOML configuration file.
+            May also be a dict of keyword arguments to be parsed.
         """
-        # arguments serve modularity; pylint: disable=too-many-arguments
-        configs = self._parse_config_dicts(regst_cfg, train_cfg, valid_cfg)
-        if isinstance(early_stop, dict):
-            early_stop = EarlyStopping(**early_stop)
-        if not (isinstance(early_stop, EarlyStopping) or early_stop is None):
-            raise TypeError("'early_stop' must be None, int or EarlyStopping.")
-        asyncio.run(self.async_run(rounds, configs, early_stop))
+        if isinstance(config, dict):
+            config = FLRunConfig.from_params(**config)
+        if isinstance(config, str):
+            config = FLRunConfig.from_toml(config)  # type: ignore
+        if not isinstance(config, FLRunConfig):
+            raise TypeError("'config' should be a FLRunConfig object or str.")
+        asyncio.run(self.async_run(config))
+
     async def async_run(
         self,
-        rounds: int,
-        configs: Tuple[RegisterConfig, TrainingConfig, EvaluateConfig],
-        early_stop: Optional[EarlyStopping],
+        config: FLRunConfig,
     ) -> None:
         """Orchestrate the federated learning routine.
@@ -182,33 +135,37 @@ class FederatedServer:
 
         Parameters
         ----------
-        rounds: int
-            Maximum number of training rounds to perform.
-        configs: (RegisterConfig, TrainingConfig, EvaluateConfig) tuple
-            Dataclass instances wrapping hyper-parameters specifying
-            clients-registration and training and evaluation rounds.
-        early_stop: EarlyStopping or None
-            Optional EarlyStopping instance adding a stopping criterion
-            based on the global-evaluation-loss's evolution over rounds.
+        config: FLRunConfig
+            Container instance wrapping grouped hyper-parameters that
+            specify the federated learning process, including clients
+            registration, training and validation rounds' setup, plus
+            an optional early-stopping criterion.
         """
-        regst_cfg, train_cfg, valid_cfg = configs
+        # Instantiate the early-stopping criterion, if any.
+        early_stop = None  # type: Optional[EarlyStopping]
+        if config.early_stop is not None:
+            early_stop = config.early_stop.instantiate()
+        # Start the communications server and run the FL process.
         async with self.netwk:
-            await self.initialization(regst_cfg)
+            # Conduct the initialization phase.
+            await self.initialization(config)
             self.checkpointer.save_model()
             self.checkpointer.checkpoint(float("inf"))  # save initial weights
+            # Iteratively run training and evaluation rounds.
             round_i = 0
             while True:
                 round_i += 1
-                await self.training_round(round_i, train_cfg)
-                await self.evaluation_round(round_i, valid_cfg)
-                if not self._keep_training(round_i, rounds, early_stop):
+                await self.training_round(round_i, config.training)
+                await self.evaluation_round(round_i, config.evaluate)
+                if not self._keep_training(round_i, config.rounds, early_stop):
                     break
+            # Interrupt training when time comes.
             self.logger.info("Stopping training.")
             await self.stop_training(round_i)
 
     async def initialization(
         self,
-        regst_cfg: RegisterConfig,
+        config: FLRunConfig,
     ) -> None:
         """Orchestrate the initialization steps to set up training.
@@ -217,6 +174,13 @@ class FederatedServer:
         Await clients to have finalized their initialization step; raise
         and cancel training if issues are reported back.
 
+        Parameters
+        ----------
+        config: FLRunConfig
+            Container instance wrapping hyper-parameters that specify
+            the planned federated learning process, including clients
+            registration ones as a RegisterConfig dataclass instance.
+
         Raises
         ------
         RuntimeError:
@@ -224,6 +188,8 @@ class FederatedServer:
             than an Empty ping-back message. Send CancelTraining to all
             clients before raising.
         """
+        # Gather the RegisterConfig instance from the main FLRunConfig.
+        regst_cfg = config.register
         # Wait for clients to register and process their data information.
         self.logger.info("Starting clients registration process.")
         data_info = await self.netwk.wait_for_clients(
diff --git a/declearn/main/config/__init__.py b/declearn/main/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d32c74a739b7b426985c3d262fef95a235e5b3f7
--- /dev/null
+++ b/declearn/main/config/__init__.py
@@ -0,0 +1,19 @@
+# coding: utf-8
+
+"""Tools to specify hyper-parameters of a Federated Learning process.
+
+This submodule exposes dataclasses that group and document server-side
+hyper-parameters that specify a Federated Learning process, as well as
+a main class designed to act as a container and a parser for all these,
+that may be instantiated from python objects or from a TOML file.
+
+In other words, `FLRunConfig` is the key class implemented here, while
+the other exposed dataclasses are already articulated and used by it.
+"""
+
+from ._dataclasses import (
+    EvaluateConfig,
+    RegisterConfig,
+    TrainingConfig,
+)
+from ._run_config import FLRunConfig
diff --git a/declearn/main/utils/_dataclasses.py b/declearn/main/config/_dataclasses.py
similarity index 100%
rename from declearn/main/utils/_dataclasses.py
rename to declearn/main/config/_dataclasses.py
diff --git a/declearn/main/config/_run_config.py b/declearn/main/config/_run_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4836b93575e1c56fbccef9cb1e770021b325b2ef
--- /dev/null
+++ b/declearn/main/config/_run_config.py
@@ -0,0 +1,106 @@
+# coding: utf-8
+
+"""TOML-parsable container for Federated Learning "run" configurations."""
+
+import dataclasses
+from typing import Any, Optional
+
+
+from declearn.main.utils import EarlyStopConfig
+from declearn.main.config._dataclasses import (
+    EvaluateConfig,
+    TrainingConfig,
+    RegisterConfig,
+)
+from declearn.utils import TomlConfig
+
+
+__all__ = [
+    "FLRunConfig",
+]
+
+
+@dataclasses.dataclass
+class FLRunConfig(TomlConfig):
+    """Global container for Federated Learning "run" configurations.
+
+    This class aims at wrapping multiple, possibly optional, sets of
+    hyper-parameters that parameterize a Federated Learning process,
+    each of which is specified through a dedicated dataclass or as a
+    unit python type.
+
+    It is designed to be used by an orchestrator, e.g. the server in
+    the case of a centralized federated learning process.
+
+    This class is meant to be extendable through inheritance, so as
+    to refine the expected fields or add some that might be used by
+    children (or parallel) classes of `FederatedServer` that modify
+    the default, centralized, federated learning process.
+
+    Fields
+    ------
+    rounds: int
+        Maximum number of training and validation rounds to perform.
+    register: RegisterConfig
+        Parameters for clients' registration (min and/or max number
+        of clients to expect, optional max duration of the process).
+    training: TrainingConfig
+        Parameters for training rounds, including effort constraints
+        and data-batching instructions.
+    evaluate: EvaluateConfig
+        Parameters for validation rounds, similar to training ones.
+    early_stop: EarlyStopConfig or None, default=None
+        Optional parameters to set up an EarlyStopping criterion, to
+        be leveraged so as to interrupt the federated learning process
+        based on the tracking of a minimized quantity (e.g. model loss).
+
+    Instantiation classmethods
+    --------------------------
+    from_toml:
+        Instantiate by parsing a TOML configuration file.
+    from_params:
+        Instantiate by parsing inputs dicts (or objects).
+    """
+
+    rounds: int
+    register: RegisterConfig
+    training: TrainingConfig
+    evaluate: EvaluateConfig
+    early_stop: Optional[EarlyStopConfig] = None  # type: ignore  # is a type
+
+    @classmethod
+    def parse_register(
+        cls,
+        field: dataclasses.Field[RegisterConfig],
+        inputs: Any,
+    ) -> RegisterConfig:
+        """Field-specific parser to instantiate a RegisterConfig.
+
+        This method supports specifying `register`:
+        * as a single int, translated into {"min_clients": inputs}
+        * as None (or missing kwarg), using default RegisterConfig()
+
+        It otherwise routes inputs back to the `default_parser`.
+        """
+        if inputs is None:
+            return RegisterConfig()
+        if isinstance(inputs, int):
+            return RegisterConfig(min_clients=inputs)
+        return cls.default_parser(field, inputs)
+
+    @classmethod
+    def from_params(
+        cls,
+        **kwargs: Any,
+    ) -> "FLRunConfig":
+        # If evaluation batch size is not set, use the same as training.
+        # Note: if inputs have invalid formats, let the parent method fail.
+        evaluate = kwargs.setdefault("evaluate", {})
+        if isinstance(evaluate, dict):
+            training = kwargs.get("training")
+            if isinstance(training, dict):
+                evaluate.setdefault("batch_size", training.get("batch_size"))
+            elif isinstance(training, TrainingConfig):
+                evaluate.setdefault("batch_size", training.batch_size)
+        # Delegate the rest of the work to the parent method.
+        return super().from_params(**kwargs)
diff --git a/declearn/main/utils/__init__.py b/declearn/main/utils/__init__.py
index b779241e131c8e46c9999f846c4f3b1f0af0b2bb..d3e2228a1beb113e72e5052eedf2760748d41df2 100644
--- a/declearn/main/utils/__init__.py
+++ b/declearn/main/utils/__init__.py
@@ -5,5 +5,4 @@
 from ._checkpoint import Checkpointer
 from ._constraints import Constraint, ConstraintSet, TimeoutConstraint
 from ._data_info import AggregationError, aggregate_clients_data_info
-from ._dataclasses import EvaluateConfig, RegisterConfig, TrainingConfig
 from ._early_stop import EarlyStopping, EarlyStopConfig
diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py
index 1f25a93c45e112111dfdd2b9cb13c2a196d5a4ca..2f23e2bc4662e88badc6b67a50e9a06b1d78a12e 100644
--- a/examples/heart-uci/server.py
+++ b/examples/heart-uci/server.py
@@ -5,6 +5,7 @@
 import os
 
 from declearn.communication import NetworkServerConfig
 from declearn.main import FederatedServer
+from declearn.main.config import FLRunConfig
 from declearn.model.sklearn import SklearnSGDModel
 from declearn.strategy import strategy_from_config
@@ -85,13 +86,14 @@ def run_server(
     # Here, we setup 20 rounds of training, with 30 samples per batch
     # during training and 50 during validation; plus an early-stopping
     # criterion if the global validation loss stops decreasing for 5 rounds.
- server.run( + run_cfg = FLRunConfig.from_params( rounds=20, - regst_cfg={"min_clients": nb_clients}, - train_cfg={"batch_size": 30, "drop_remainder": False}, - valid_cfg={"batch_size": 50, "drop_remainder": False}, + register={"min_clients": nb_clients}, + training={"batch_size": 30, "drop_remainder": False}, + evaluate={"batch_size": 50, "drop_remainder": False}, early_stop={"tolerance": 0.0, "patience": 5, "relative": False}, ) + server.run(run_cfg) # Called when the script is called directly (using `python server.py`). diff --git a/test/test_main.py b/test/test_main.py index 085a96afc4b7641c53efee70e6532455f10c5e42..6285c31d8868511ab6fc9db9c5029f313e3b12e1 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -178,11 +178,12 @@ class DeclearnTestCase: strat = self.strategy(eta_l=0.01) with tempfile.TemporaryDirectory() as folder: server = FederatedServer(model, netwk, strat, folder=folder) - server.run( - rounds=self.rounds, - regst_cfg={"max_clients": self.nb_clients, "timeout": 20}, - train_cfg={"batch_size": 100}, - ) + config = { + "rounds": self.rounds, + "register": {"max_clients": self.nb_clients, "timeout": 20}, + "training": {"batch_size": 100}, + } + server.run(config) def run_federated_client( self,
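
For reference, here is a minimal usage sketch of the reworked `FederatedServer.run` API introduced above. The server construction is elided, and the TOML layout shown in the comment is an assumption derived from the `FLRunConfig` field names rather than taken from the patch itself.

```python
from declearn.main import FederatedServer
from declearn.main.config import FLRunConfig

# `server` is assumed to be a FederatedServer instance, built e.g. as in
# the README snippet (model, network config, strategy, output folder).
server = ...  # placeholder for an actual FederatedServer

# Option 1: pass an FLRunConfig instance built via `from_params`.
config = FLRunConfig.from_params(
    rounds=10,
    register={"min_clients": 1, "max_clients": 3, "timeout": 180},
    training={"n_epochs": 5, "batch_size": 128, "drop_remainder": False},
)
server.run(config)

# Option 2: pass a dict of keyword arguments, parsed via `from_params`
# (mirroring the updated test_main.py usage).
server.run(
    {
        "rounds": 10,
        "register": {"min_clients": 1},
        "training": {"batch_size": 128},
    }
)

# Option 3: pass the path to a TOML file, parsed via `from_toml`.
# Hypothetical "run_config.toml" layout (section names mirror the
# FLRunConfig fields; the actual schema is defined by TomlConfig):
#     rounds = 10
#     [register]
#     min_clients = 1
#     [training]
#     batch_size = 128
server.run("run_config.toml")
```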