diff --git a/README.md b/README.md
index fa9594dcfd7cab66aad5ebc0a83f191c0e6de6a4..947a0c4220842c78de721c5c774173d88c15caef 100644
--- a/README.md
+++ b/README.md
@@ -168,7 +168,10 @@ netwk = declearn.communication.NetworkServerConfig(
     certificate="path/to/certificate.pem",
     private_key="path/to/private_key.pem"
 )
-strat = declearn.strategy.FedAvg()
-server = declearn.main.FederatedServer(model, netwk, strat, folder="outputs")
+optim = declearn.main.config.FLOptimConfig.from_params(
+    aggregator="averaging",
+    client_opt={"lrate": 0.001},
+)
+server = declearn.main.FederatedServer(model, netwk, optim, folder="outputs")
 config = declearn.main.config.FLRunConfig.from_params(
     rounds=10,
@@ -418,25 +421,22 @@ details on this example and on how to run it, please refer to its own
    - Select the appropriate `declearn.model.api.Model` subclass to wrap it up.
    - Either instantiate the `Model` or provide a JSON-serialized configuration.
 
-2. Define a Strategy:
-
-   - Select an out-of-the-box `declearn.strategy.Strategy` subclass that
-    defines the aggregation and optimization strategies for the process
-    (_e.g._ `declearn.strategy.FedAvg` or `declearn.strategy.Scaffold`)
-   - Parameterize and instantiate it.
-    <br/>**- OR -**
-   - Select and parameterize a `declearn.strategy.Aggregator` (subclass)
-    instance to define how clients' updates are to be aggregated into
-    global-model updates on the server side.
+2. Define an FLOptimConfig:
+   - Select a `declearn.aggregator.Aggregator` (subclass) instance to define
+     how clients' updates are to be aggregated into global-model updates on
+     the server side.
    - Parameterize a `declearn.optimizer.Optimizer` (possibly using a selected
-    pipeline of `declearn.optimizer.modules.OptiModule` plug-ins and/or a
-    pipeline of `declearn.optimizer.regularizers.Regularizer` ones) to be
-    used by clients to derive local step-wise updates from model gradients.
+     pipeline of `declearn.optimizer.modules.OptiModule` plug-ins and/or a
+     pipeline of `declearn.optimizer.regularizers.Regularizer` ones) to be
+     used by clients to derive local step-wise updates from model gradients.
    - Similarly, parameterize an `Optimizer` to be used by the server to
-    (optionally) refine the aggregated model updates before applying them.
-   - Wrap these three objects into a custom `Strategy` using
-    `declearn.strategy.strategy_from_config`. Use instantiated objects'
-    `get_config` method if needed to abide by the former function's specs.
+     (optionally) refine the aggregated model updates before applying them.
+   - Wrap these three objects into a `declearn.main.config.FLOptimConfig`,
+     possibly using its `from_config` method to specify the former three
+     components via configuration dicts rather than actual instances.
+   - Alternatively, write up a TOML configuration file that specifies these
+     components (note that 'aggregator' and 'server_opt' have default values
+     and may therefore be left unspecified).
 
 3. Define a communication Server:
 
@@ -449,7 +449,7 @@ details on this example and on how to run it, please refer to its own
 4. Instantiate and run a FederatedServer:
 
    - Instantiate a `declearn.main.FederatedServer`:
-     - Provide the Model, Strategy and Server objects or configurations.
+     - Provide the Model, FLOptimConfig and Server objects or configurations.
      - Optionally provide the path to a folder where to write output files
       (model checkpoints and global loss history).
    - Instantiate a `declearn.main.config.FLRunConfig` to specify the process:
diff --git a/declearn/__init__.py b/declearn/__init__.py
index c21e6bc679dc661472a7cc3d0e95a4266564c561..4e2316d9c1ac4a39ba7d31629a07135244e12e75 100644
--- a/declearn/__init__.py
+++ b/declearn/__init__.py
@@ -31,8 +31,6 @@ The package is organized into the following submodules:
     Model interfacing API and implementations.
 * optimizer:
     Framework-agnostic optimizer and algorithmic plug-ins API and tools.
-* strategy:
-    Interface to gather an Aggregator and a pair of Optimizer into a strategy.
 * typing:
     Type hinting utils, defined and exposed for code readability purposes.
 * utils:
@@ -47,7 +45,6 @@ from . import dataset
 from . import model
 from . import optimizer
 from . import aggregator
-from . import strategy
 from . import main
 
 __version__ = "2.0.0b2"
diff --git a/declearn/main/_server.py b/declearn/main/_server.py
index 7f340ede643c7ccb235a8c77d700c562127f676d..603bc3082267278fc0c9e902ba1010edc48678a1 100644
--- a/declearn/main/_server.py
+++ b/declearn/main/_server.py
@@ -11,6 +11,7 @@ from declearn.communication import NetworkServerConfig, messaging
 from declearn.communication.api import Server
 from declearn.main.config import (
     EvaluateConfig,
+    FLOptimConfig,
     FLRunConfig,
     TrainingConfig,
 )
@@ -21,7 +22,6 @@ from declearn.main.utils import (
     aggregate_clients_data_info,
 )
 from declearn.model.api import Model
-from declearn.strategy import Strategy
 from declearn.utils import deserialize_object, get_logger
 
 
@@ -40,7 +40,7 @@ class FederatedServer:
         self,
         model: Union[Model, str, Dict[str, Any]],
         netwk: Union[Server, NetworkServerConfig, Dict[str, Any]],
-        strategy: Strategy,  # future: revise Strategy, add config
+        optim: Union[FLOptimConfig, str, Dict[str, Any]],
         folder: Optional[str] = None,
         logger: Union[logging.Logger, str, None] = None,
     ) -> None:
@@ -56,10 +56,11 @@
             dict or dataclass enabling its instantiation.
             In the latter two cases, the object's default logger
             will be set to that of this `FederatedClient`.
-        strategy: Strategy
-            Strategy instance providing with instantiation methods for
-            the server's updates-aggregator, the server-side optimizer
-            and the clients-side one.
+        optim: FLOptimConfig or dict or str
+            FLOptimConfig instance, or an instantiation dict (for
+            the `from_params` method), or a TOML file path.
+            This specifies the optimizers to be used by the clients
+            and the server, as well as the client-updates aggregator.
         folder: str or None, default=None
             Optional folder where to write out a model dump, round-
             wise weights checkpoints and global validation losses.
@@ -96,10 +97,20 @@
                 "or the valid configuration of one."
             )
         self.netwk = netwk
-        # Assign the strategy and instantiate server-side objects.
-        self.strat = strategy
-        self.aggrg = self.strat.build_server_aggregator()
-        self.optim = self.strat.build_server_optimizer()
+        # Assign the wrapped FLOptimConfig.
+        if isinstance(optim, str):
+            optim = FLOptimConfig.from_toml(optim)
+        elif isinstance(optim, dict):
+            optim = FLOptimConfig.from_params(**optim)
+        if not isinstance(optim, FLOptimConfig):
+            raise TypeError(
+                "'optim' should be a declearn.main.config.FLOptimConfig "
+                "or a dict of parameters or the path to a TOML file from "
+                "which to instantiate one."
+            )
+        self.aggrg = optim.aggregator
+        self.optim = optim.server_opt
+        self.c_opt = optim.client_opt
         # Assign a model checkpointer.
         self.checkpointer = Checkpointer(self.model, folder)
 
@@ -202,8 +213,8 @@
         # Serialize intialization information and send it to clients.
         message = messaging.InitRequest(
             model=self.model,
-            optim=self.strat.build_client_optimizer(),
-        )  # revise: strategy rather than optimizer?
+            optim=self.c_opt,
+        )
         self.logger.info("Sending initialization requests to clients.")
         await self.netwk.broadcast_message(message)
         # Await a confirmation from clients that initialization went well.
@@ -393,7 +404,8 @@
                 aux_var.setdefault(module, {})[client] = params
         self.optim.process_aux_var(aux_var)
         # Compute aggregated "gradients" (updates) and apply them to the model.
-        gradients = self.aggrg.aggregate(  # revise: pass n_epoch / t_spent / ?
+        # revise: pass n_epoch / t_spent / ?
+        gradients = self.aggrg.aggregate(
             {client: result.updates for client, result in results.items()},
             {client: result.n_steps for client, result in results.items()},
         )
diff --git a/declearn/strategy/__init__.py b/declearn/strategy/__init__.py
deleted file mode 100644
index 51830f6e2591107789d701225aa1addff5932dd8..0000000000000000000000000000000000000000
--- a/declearn/strategy/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# coding: utf-8
-
-"""Federated Learning Strategy definition API and examples submodule."""
-
-from ._strategy import (
-    FedAvg,
-    Strategy,
-    strategy_from_config,
-)
-from ._strategies import (
-    FedAvgM,
-    Scaffold,
-    ScaffoldM,
-)
diff --git a/declearn/strategy/_strategies.py b/declearn/strategy/_strategies.py
deleted file mode 100644
index 3680f73b203b2863f5ab08e0557dadbf0e780bbf..0000000000000000000000000000000000000000
--- a/declearn/strategy/_strategies.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# coding: utf-8
-
-"""Strategy scratch code rewrite."""
-
-from abc import ABCMeta
-from typing import List
-
-from declearn.optimizer.modules import (
-    OptiModule,
-    MomentumModule,
-    ScaffoldClientModule,
-    ScaffoldServerModule,
-)
-from declearn.strategy._strategy import FedAvg, Strategy
-
-__all__ = [
-    "FedAvgM",
-    "Scaffold",
-    "ScaffoldM",
-]
-
-
-class FedAvgM(FedAvg):
-    """FedAvgM Strategy defining class.
-
-    FedAvgM, or FedAvg with Momentum, is a Strategy extending
-    FedAvg to use momentum when applying aggregated updates
-    on the server side, as first proposed in [1].
-
-    References
-    ----------
-    [1] Hsu et al., 2019
-        Measuring the Effects of Non-Identical Data Distribution
-        for Federated Visual Classification.
-        https://arxiv.org/abs/1909.06335
-    """
-
-    def __init__(
-        self,
-        eta_l: float = 1e-4,
-        eta_g: float = 1.0,
-        lam_l: float = 0.0,
-        lam_g: float = 0.0,
-        beta: float = 0.9,
-    ) -> None:
-        """Instantiate the FedAvgM Strategy.
-
-        Parameters
-        ----------
-        eta_l: float, default=0.0001,
-            Learning rate parameter of clients' optimizer.
-        eta_g: float, default=1.
-            Learning rate parameter of the server's optimizer.
-            Defaults to 1 so as to merely average local updates.
-        lam_l: float, default=0.
-            Weight decay parameter of clients' optimizer.
-            Defaults to 0 so as not to use any weight decay.
-        lam_g: float, default=0.
-            Weight decay parameter of the server's optimizer.
-            Defaults to 0 so as not to use any weight decay.
-        beta: float, default=.9
-            EWMA parameter applied to aggregated updates.
-            See `declearn.optimizer.modules.EWMAModule`.
- """ - # arguments serve modularity; pylint: disable=too-many-arguments - super().__init__(eta_l, eta_g, lam_l, lam_g) - self.beta = beta - - def _build_server_modules( - self, - ) -> List[OptiModule]: - modules = super()._build_server_modules() - modules.append(MomentumModule(self.beta)) - return modules - - -class _ScaffoldMixin(Strategy, metaclass=ABCMeta): - """Mix-in class to use SCAFFOLD on top of a base Strategy. - - SCAFFOLD, or Stochastic Controlled Averaging for Federated - Learning, is a modification of the base federated learning - process to regularize clients' drift away from the shared - model during the local training steps. - - It is implemented using a pair of OptiModule objects that - exchange and maintain state variables through training. - - See `declearn.optimizer.modules.ScaffoldClientModule` and - `ScaffoldServerModule` for details. - """ - - def _build_server_modules( - self, - ) -> List[OptiModule]: - modules = super()._build_server_modules() - modules.append(ScaffoldServerModule()) - return modules - - def _build_client_modules( - self, - ) -> List[OptiModule]: - modules = super()._build_client_modules() - modules.append(ScaffoldClientModule()) - return modules - - -class Scaffold(_ScaffoldMixin, FedAvg): - """Scaffold Strategy defining class. - - SCAFFOLD, or Stochastic Controlled Averaging for Federated Learning, - is a modification of FedAvg that applies a correction term to local - gradients at each SGD step in order to prevent clients' models from - drifting away too much from the shared model. - - It relies on the use of state variables that are maintained through - time and updated between rounds, based on the clients' sharing state - information with the server and receiving updates in return. - - This class implements SCAFFOLD on top of the base FedAvg Strategy. - See `declearn.optimizer.modules.ScaffoldClientModule` and - `ScaffoldServerModule` for details on SCAFFOLD. - See `declearn.strategy.FedAvg` for the base FedAvg class. - """ - - -class ScaffoldM(_ScaffoldMixin, FedAvgM): - """ScaffoldM Strategy defining class. - - ScaffoldM is SCAFFOLD (see `declearn.strategy.Scaffold`) combined - with the use of momentum when applying aggregated upgrades to the - global model. In other words, it is SCAFFOLD on top of FedAvgM. - """ diff --git a/declearn/strategy/_strategy.py b/declearn/strategy/_strategy.py deleted file mode 100644 index 1402fa244594bca40776f12b953fe56566fc4624..0000000000000000000000000000000000000000 --- a/declearn/strategy/_strategy.py +++ /dev/null @@ -1,205 +0,0 @@ -# coding: utf-8 - -"""Strategy scratch code rewrite.""" - -from abc import ABCMeta, abstractmethod -import dataclasses -from typing import Any, Dict, List, Union - - -from declearn.aggregator import Aggregator, AveragingAggregator -from declearn.optimizer import Optimizer -from declearn.optimizer.modules import OptiModule -from declearn.utils import deserialize_object, json_load - - -__all__ = [ - "FedAvg", - "Strategy", - "strategy_from_config", -] - - -class Strategy(metaclass=ABCMeta): - """Base class to define a client/server FL Strategy. - - This class is meant to design an API enabling the modular design - of Federated Learning strategies, which are defined by: - * an updates-aggregation algorithm - * a server-side optimization algorithm, to refine and apply - aggregated updates - * a client-side optimization algorithm, to refine and apply - step-wise gradient-based updates - * (opt.) 
a client-sampling policy, to select participating - clients to a given training round - - At the moment, the design of this class is *unfinished*. - Notably, in addition to the algorithmic modularity, the - future aim will be to have a modular way to instantiate - a strategy (e.g. using configuration files, authorizing - some level of client-wise overload, etc.). - """ - - @abstractmethod - def build_server_aggregator( - self, - ) -> Aggregator: - """Set up and return an Aggregator to be used by the server.""" - raise NotImplementedError - - @abstractmethod - def build_server_optimizer( - self, - ) -> Optimizer: - """Set up and return an Optimizer to be used by the server.""" - raise NotImplementedError - - def _build_server_modules( - self, - ) -> List[OptiModule]: - """Return a list of OptiModule plug-ins for the server to use.""" - return [] - - @abstractmethod - def build_client_optimizer( - self, - ) -> Optimizer: - """Set up and return an Optimizer to be used by clients.""" - raise NotImplementedError - - def _build_client_modules( - self, - ) -> List[OptiModule]: - """Return a list of OptiModule plug-ins for clients to use.""" - return [] - - # revise: add this once clients-sampling is implemented - # @abstractmethod - # def build_clients_sampler( - # self, - # ) -> ClientsSelector: - # """Docstring.""" - - -@dataclasses.dataclass -class AggregConfig: - """Dataclass specifying server aggregator config (and default).""" - - name: str = "Average" - group: str = "Aggregator" - config: Dict[str, Any] = dataclasses.field(default_factory=dict) - - -@dataclasses.dataclass -class ClientConfig: - """Dataclass specifying client-side optimizer config (and default).""" - - lrate: float = 1e-4 - w_decay: float = 0.0 - modules: List[OptiModule] = dataclasses.field(default_factory=list) - - -@dataclasses.dataclass -class ServerConfig: - """Dataclass specifying server-side optimizer config (and default).""" - - lrate: float = 1.0 - w_decay: float = 0.0 - modules: List[OptiModule] = dataclasses.field(default_factory=list) - - -def strategy_from_config( # revise: generalize this (into Strategy?) - config: Union[str, Dict[str, Any]], -) -> Strategy: - """Define a custom Strategy from a configuration file.""" - if isinstance(config, str): - config = json_load(config) - if not isinstance(config, dict): - raise TypeError("'config' should be a dict or JSON-file-stored dict.") - # Parse the configuration dict (raise if keys are unproper). - aggreg_cfg = AggregConfig(**config.get("aggregator", {})) - client_cfg = ClientConfig(**config.get("client_opt", {})) - server_cfg = ServerConfig(**config.get("server_opt", {})) - # Declare a custom class that makes use of the previous. - class CustomStrategy(Strategy): - """Custom strategy defined from a configuration file.""" - - def build_server_aggregator(self) -> Aggregator: - cfg = dataclasses.asdict(aggreg_cfg) - agg = deserialize_object(cfg) # type: ignore - if not isinstance(agg, Aggregator): - raise TypeError("Unproper object instantiated as aggregator.") - return agg - - def build_server_optimizer(self) -> Optimizer: - return Optimizer(**dataclasses.asdict(server_cfg)) - - def build_client_optimizer(self) -> Optimizer: - return Optimizer(**dataclasses.asdict(client_cfg)) - - # Instantiate from the former and return. - return CustomStrategy() - - -class FedAvg(Strategy): - """FedAvg Strategy defining class. - - FedAvg is one of the simplest Federated Learning strategies - existing. 
This implementation allows for a few tricks to be - used, but has default values that leave these out. - - FedAvg is characterized by: - * A simple averaging of local updates to aggregate them into - global updates (here with a default behaviour to reweight - clients' contributions based on the number of steps taken). - * The use of simple SGD by clients (here with a default step - side of 1e-4 and optional weight decay). - * The absence of refinement of averaged updates on the server - side (here with the possibility to enforce a learning rate, - aka a slowdown parameter, and use optional weight decay). - """ - - def __init__( - self, - eta_l: float = 1e-4, - eta_g: float = 1.0, - lam_l: float = 0.0, - lam_g: float = 0.0, - ) -> None: - """Instantiate the FedAvg Strategy. - - Parameters - ---------- - eta_l: float, default=0.0001, - Learning rate parameter of clients' optimizer. - eta_g: float, default=1. - Learning rate parameter of the server's optimizer. - Defaults to 1 so as to merely average local updates. - lam_l: float, default=0. - Weight decay parameter of clients' optimizer. - Defaults to 0 so as not to use any weight decay. - lam_g: float, default=0. - Weight decay parameter of the server's optimizer. - Defaults to 0 so as not to use any weight decay. - """ - self.eta_l = eta_l - self.eta_g = eta_g - self.lam_l = lam_l - self.lam_g = lam_g - - def build_client_optimizer( - self, - ) -> Optimizer: - modules = self._build_client_modules() - return Optimizer(self.eta_l, self.lam_l, modules=modules) - - def build_server_optimizer( - self, - ) -> Optimizer: - modules = self._build_server_modules() - return Optimizer(self.eta_g, self.lam_g, modules=modules) - - def build_server_aggregator( - self, - ) -> Aggregator: - return AveragingAggregator(steps_weighted=True) diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py index 2f23e2bc4662e88badc6b67a50e9a06b1d78a12e..91cdad3f8f89f461739e533fa3b8c24002b24e3e 100644 --- a/examples/heart-uci/server.py +++ b/examples/heart-uci/server.py @@ -5,9 +5,9 @@ import os from declearn.communication import NetworkServerConfig from declearn.main import FederatedServer -from declearn.main.config import FLRunConfig +from declearn.main.config import FLRunConfig, FLOptimConfig from declearn.model.sklearn import SklearnSGDModel -from declearn.strategy import strategy_from_config + FILEDIR = os.path.dirname(os.path.abspath(__file__)) @@ -37,13 +37,13 @@ def run_server( kind="classifier", loss="log_loss", penalty="l2", alpha=0.005 ) - # (2) Define a strategy + # (2) Define an optimization strategy # Configure the aggregator to use. # Here, averaging weighted by the effective number # of local gradient descent steps taken. aggregator = { - "name": "Average", + "name": "averaging", "config": {"steps_weighted": True}, } @@ -61,13 +61,12 @@ def run_server( "modules": [("momentum", {"beta": 0.95})], } - # Wrap this up into a Strategy object$ - config = { - "aggregator": aggregator, - "client_opt": client_opt, - "server_opt": server_opt, - } - strategy = strategy_from_config(config) + # Wrap this up into an OptimizationStrategy object. + optim = FLOptimConfig.from_params( + aggregator=aggregator, + client_opt=client_opt, + server_opt=server_opt, + ) # (3) Define network communication parameters. @@ -82,7 +81,7 @@ def run_server( # (4) Instantiate and run a FederatedServer. 
-    server = FederatedServer(model, network, strategy)
+    server = FederatedServer(model, network, optim)
     # Here, we setup 20 rounds of training, with 30 samples per batch
     # during training and 50 during validation; plus an early-stopping
     # criterion if the global validation loss stops decreasing for 5 rounds.
diff --git a/test/test_main.py b/test/test_main.py
index 6285c31d8868511ab6fc9db9c5029f313e3b12e1..85fb51d97d473526dd83c838bc05641c5cc013dc 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -4,7 +4,7 @@
 
 import tempfile
 import warnings
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import numpy as np
 import pytest
@@ -23,7 +23,6 @@ from declearn.model.sklearn import SklearnSGDModel
 from declearn.model.tensorflow import TensorflowModel
 from declearn.model.torch import TorchModel
 from declearn.main import FederatedClient, FederatedServer
-from declearn.strategy import FedAvg, FedAvgM, Scaffold, ScaffoldM
 from declearn.test_utils import run_as_processes
 
 
@@ -46,9 +45,7 @@ class DeclearnTestCase:
         # arguments provide modularity; pylint: disable=too-many-arguments
         self.kind = kind
         self.framework = framework
-        self.strategy = {
-            cls.__name__: cls for cls in (FedAvg, FedAvgM, Scaffold, ScaffoldM)
-        }[strategy]
+        self.strategy = strategy
         self.nb_clients = nb_clients
         self.protocol = protocol
         self.use_ssl = use_ssl
@@ -169,15 +166,30 @@ class DeclearnTestCase:
             certificate=self.ssl_cert["client_cert"] if self.use_ssl else None,
         )
 
+    def build_optim_config(self) -> Dict[str, Any]:
+        """Return parameters to instantiate a FLOptimConfig."""
+        client_modules = []
+        server_modules = []
+        if self.strategy in ("Scaffold", "ScaffoldM"):
+            client_modules.append("scaffold-client")
+            server_modules.append("scaffold-server")
+        if self.strategy in ("FedAvgM", "ScaffoldM"):
+            server_modules.append("momentum")
+        return {
+            "aggregator": "averaging",
+            "client_opt": {"lrate": 0.01, "modules": client_modules},
+            "server_opt": {"lrate": 1.0, "modules": server_modules},
+        }
+
     def run_federated_server(
         self,
     ) -> None:
         """Set up and run a FederatedServer."""
         model = self.build_model()
         netwk = self.build_netwk_server()
-        strat = self.strategy(eta_l=0.01)
+        optim = self.build_optim_config()
         with tempfile.TemporaryDirectory() as folder:
-            server = FederatedServer(model, netwk, strat, folder=folder)
+            server = FederatedServer(model, netwk, optim, folder=folder)
             config = {
                 "rounds": self.rounds,
                 "register": {"max_clients": self.nb_clients, "timeout": 20},
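---

As a companion illustration (not part of the patch itself), the sketch below shows how the pieces introduced by this diff fit together on the server side: an `FLOptimConfig` replaces the removed `Strategy` classes as the third argument to `FederatedServer`. Class and parameter names are taken from the hunks above where possible; `SklearnSGDModel.from_parameters`, the `NetworkServerConfig` arguments, and the final `server.run(...)` call are assumptions inferred from the surrounding examples rather than verified API.

```python
# Hypothetical end-to-end server setup using the new FLOptimConfig API.
from declearn.communication import NetworkServerConfig
from declearn.main import FederatedServer
from declearn.main.config import FLOptimConfig, FLRunConfig
from declearn.model.sklearn import SklearnSGDModel

# Model: a scikit-learn SGD classifier wrapped for declearn, mirroring
# the heart-uci example above (constructor name assumed).
model = SklearnSGDModel.from_parameters(kind="classifier", loss="log_loss")

# Optimization setup, replacing the removed Strategy classes:
# "averaging" with steps_weighted=True mirrors the former FedAvg default;
# a ("momentum", {...}) module in server_opt mirrors FedAvgM, and adding
# scaffold-client / scaffold-server modules would mirror Scaffold.
optim = FLOptimConfig.from_params(
    aggregator={"name": "averaging", "config": {"steps_weighted": True}},
    client_opt={"lrate": 0.01, "modules": [("momentum", {"beta": 0.9})]},
    server_opt={"lrate": 1.0},
)

# Network endpoint; these parameter values are assumed, not from the diff.
netwk = NetworkServerConfig(protocol="websockets", host="localhost", port=8765)

# The server now takes the FLOptimConfig (or a params dict, or a TOML path).
server = FederatedServer(model, netwk, optim, folder="outputs")
config = FLRunConfig.from_params(
    rounds=10,
    register={"max_clients": 3, "timeout": 60},
    training={"batch_size": 30},
)
server.run(config)
```

Per the `FLOptimConfig.from_toml` branch added in `_server.py`, the `optim` argument may alternatively be the path to a TOML file specifying `client_opt` (and optionally `aggregator` and `server_opt`, which have defaults).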