diff --git a/README.md b/README.md index 2141b36607e70f6795f614727c47839cab461eee..01b48f0be74d3b08d5334feebc07397124fc1342 100644 --- a/README.md +++ b/README.md @@ -182,7 +182,9 @@ optim = declearn.main.FLOptimConfig.from_params( aggregator="averaging", client_opt=0.001, ) -server = declearn.main.FederatedServer(model, netwk, optim, folder="outputs") +server = declearn.main.FederatedServer( + model, netwk, optim, checkpoint="outputs" +) config = declearn.main.config.FLRunConfig.from_params( rounds=10, register={"min_clients": 1, "max_clients": 3, "timeout": 180}, @@ -206,7 +208,7 @@ train = declearn.dataset.InMemoryDataset( expose_classes=True # enable sharing of unique target values ) valid = declearn.dataset.InMemoryDataset("path/to/valid.csv", target="label") -client = declearn.main.FederatedClient(netwk, train, valid, folder="outputs") +client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs") client.run() ``` @@ -249,9 +251,10 @@ exposed here. - decide whether to continue, based on the number of rounds taken or on the evolution of the global loss - Finally: + - restore the model weights that yielded the lowest global loss - notify clients that training is over, so they can disconnect - and run their final routine (e.g. model saving) - - optionally save the model (through a checkpointer) + and run their final routine (e.g. save the "best" model) + - optionally checkpoint the "best" model - close the network server and end the process #### Detail of the process phases @@ -319,14 +322,15 @@ exposed here. - update model weights - perform evaluation steps based on effort constraints - step: update evaluation metrics, including the model's loss, over a batch - - checkpoint the model, then send results to the server - - optionally prevent sharing detailed metrics with the server; always - include the scalar validation loss value + - optionally checkpoint the model, local optimizer and evaluation metrics + - send results to the server: optionally prevent sharing detailed metrics; + always include the scalar validation loss value - messaging: (EvaluateRequest <-> EvaluateReply) - Server: - aggregate local loss values into a global loss metric - aggregate all other evaluation metrics and log their values - - checkpoint the model and the global loss + - optionally checkpoint the model, optimizer, aggregated evaluation + metrics and client-wise ones ### Overview of the declearn API diff --git a/declearn/main/_client.py b/declearn/main/_client.py index 4bdd509831ce96fd1cb1be1ed49ca397130ae793..f0707f1627d224d849095ef2dfa43eea4a51d2f9 100644 --- a/declearn/main/_client.py +++ b/declearn/main/_client.py @@ -5,15 +5,13 @@ import asyncio import dataclasses import logging -import os from typing import Any, Dict, Optional, Union - from declearn.communication import NetworkClientConfig, messaging from declearn.communication.api import NetworkClient from declearn.dataset import Dataset, load_dataset_from_json from declearn.main.utils import Checkpointer, TrainingManager -from declearn.utils import get_logger, json_dump +from declearn.utils import get_logger __all__ = [ @@ -31,7 +29,7 @@ class FederatedClient: netwk: Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str], train_data: Union[Dataset, str], valid_data: Optional[Union[Dataset, str]] = None, - folder: Optional[str] = None, + checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None, share_metrics: bool = True, logger: Union[logging.Logger, str, None] = None, ) -> None: @@ -51,11 +49,11 @@ class 
FederatedClient: Optional Dataset instance wrapping validation data, or path to a JSON file from which it can be instantiated. If None, run evaluation rounds over `train_data`. - folder: str or None, default=None - Optional folder where to write out a model dump, round- - wise weights checkpoints and local validation losses. - If None, only record the loss metric and lowest-loss- - yielding weights in memory (under `self.checkpoint`). + checkpoint: Checkpointer or dict or str or None, default=None + Optional Checkpointer instance or instantiation dict to be + used so as to save round-wise model, optimizer and metrics. + If a single string is provided, treat it as the checkpoint + folder path and use default values for other parameters. share_metrics: bool, default=True Whether to share evaluation metrics with the server, or save them locally and only send the model's loss. @@ -102,9 +100,10 @@ class FederatedClient: if not (valid_data is None or isinstance(valid_data, Dataset)): raise TypeError("'valid_data' should be a Dataset or path to one.") self.valid_data = valid_data - # Record the checkpointing folder and create a Checkpointer slot. - self.folder = folder - self.checkpointer = None # type: Optional[Checkpointer] + # Assign an optional checkpointer. + if checkpoint is not None: + checkpoint = Checkpointer.from_specs(checkpoint) + self.ckptr = checkpoint # Record the metric-sharing boolean switch. self.share_metrics = bool(share_metrics) # Create a TrainingManager slot, populated at initialization phase. @@ -249,13 +248,16 @@ class FederatedClient: metrics=message.metrics, logger=self.logger, ) - # Instantiate a checkpointer and save the initial model. - self.checkpointer = Checkpointer(message.model, self.folder) - self.checkpointer.save_model() - self.checkpointer.checkpoint(float("inf")) # initial weights # If instructed to do so, await a PrivacyRequest to set up DP-SGD. if message.dpsgd: await self._initialize_dpsgd() + # Optionally checkpoint the received model and optimizer. + if self.ckptr: + self.ckptr.checkpoint( + model=self.trainmanager.model, + optimizer=self.trainmanager.optim, + first_call=True, + ) async def _initialize_dpsgd( self, @@ -309,6 +311,7 @@ class FederatedClient: # lazy-import the DPTrainingManager, that involves some optional, # heavy-loadtime dependencies; pylint: disable=import-outside-toplevel from declearn.main.privacy import DPTrainingManager + # pylint: enable=import-outside-toplevel self.trainmanager = DPTrainingManager( model=self.trainmanager.model, @@ -368,9 +371,13 @@ class FederatedClient: reply = self.trainmanager.evaluation_round(message) # Post-process the results. if isinstance(reply, messaging.EvaluationReply): # not an Error - # Checkpoint the model and record the local loss. - if self.checkpointer is not None: # True in `run` context - self.checkpointer.checkpoint(reply.loss) + # Optionnally checkpoint the model, optimizer and local loss. + if self.ckptr: + self.ckptr.checkpoint( + model=self.trainmanager.model, + optimizer=self.trainmanager.optim, + metrics=self.trainmanager.metrics.get_result(), + ) # Optionally prevent sharing metrics (save for the loss). if not self.share_metrics: reply.metrics.clear() @@ -393,17 +400,12 @@ class FederatedClient: message.rounds, message.loss, ) - if self.folder is not None: - # Save the locally-best-performing model weights. 
- if self.checkpointer is not None: # True in `run` context - path = os.path.join(self.folder, "best_local_weights.json") - self.logger.info("Saving best local weights in '%s'.", path) - self.checkpointer.reset_best_weights() - json_dump(self.checkpointer.model.get_weights(), path) - # Save the globally-best-performing model weights. - path = os.path.join(self.folder, "final_weights.json") - self.logger.info("Saving final weights in '%s'.", path) - json_dump(message.weights, path) + if self.ckptr: + path = f"{self.ckptr.folder}/model_state_best.json" + self.logger.info("Checkpointing final weights under %s.", path) + assert self.trainmanager is not None # for mypy + self.trainmanager.model.set_weights(message.weights) + self.ckptr.save_model(self.trainmanager.model, timestamp="best") async def cancel_training( self, diff --git a/declearn/main/_server.py b/declearn/main/_server.py index 53ad40b824f72f1bba005fd9829b48f829d68aac..a7cfc9a0ef55e097f13223a64eba9c28b2c0547a 100644 --- a/declearn/main/_server.py +++ b/declearn/main/_server.py @@ -24,7 +24,7 @@ from declearn.main.utils import ( aggregate_clients_data_info, ) from declearn.metrics import MetricInputType, MetricSet -from declearn.model.api import Model +from declearn.model.api import Model, Vector from declearn.utils import deserialize_object, get_logger @@ -47,7 +47,7 @@ class FederatedServer: netwk: Union[NetworkServer, NetworkServerConfig, Dict[str, Any], str], optim: Union[FLOptimConfig, str, Dict[str, Any]], metrics: Union[MetricSet, List[MetricInputType], None] = None, - folder: Optional[str] = None, + checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None, logger: Union[logging.Logger, str, None] = None, ) -> None: """Instantiate the orchestrating server for a federated learning task. @@ -72,11 +72,11 @@ class FederatedServer: to wrap into one, defining evaluation metrics to compute in addition to the model's loss. If None, only compute and report the model's loss. - folder: str or None, default=None - Optional folder where to write out a model dump, round- - wise weights checkpoints and global validation losses. - If None, only record the loss metric and lowest-loss- - yielding weights in memory (under `self.checkpoint`). + checkpoint: Checkpointer or dict or str or None, default=None + Optional Checkpointer instance or instantiation dict to be + used so as to save round-wise model, optimizer and metrics. + If a single string is provided, treat it as the checkpoint + folder path and use default values for other parameters. logger: logging.Logger or str or None, default=None, Logger to use, or name of a logger to set up with `declearn.utils.get_logger`. If None, use `type(self)`. @@ -125,8 +125,13 @@ class FederatedServer: self.c_opt = optim.client_opt # Assign the wrapped MetricSet. self.metrics = MetricSet.from_specs(metrics) - # Assign a model checkpointer. - self.checkpointer = Checkpointer(self.model, folder) + # Assign an optional checkpointer. + if checkpoint is not None: + checkpoint = Checkpointer.from_specs(checkpoint) + self.ckptr = checkpoint + # Set up private attributes to record the loss values and best weights. + self._loss = {} # type: Dict[int, float] + self._best = None # type: Optional[Vector] def run( self, @@ -177,8 +182,8 @@ class FederatedServer: async with self.netwk: # Conduct the initialization phase. 
await self.initialization(config) - self.checkpointer.save_model() - self.checkpointer.checkpoint(float("inf")) # save initial weights + if self.ckptr: + self.ckptr.checkpoint(self.model, self.optim, first_call=True) # Iteratively run training and evaluation rounds. round_i = 0 while True: @@ -478,6 +483,7 @@ class FederatedServer: EvaluateConfig dataclass instance wrapping data-batching and computational effort constraints hyper-parameters. """ + # Send evaluation requests and collect clients' replies. self.logger.info("Initiating evaluation round %s", round_i) clients = self._select_evaluation_round_participants() await self._send_evaluation_instructions(clients, round_i, valid_cfg) @@ -485,12 +491,22 @@ class FederatedServer: results = await self._collect_results( clients, messaging.EvaluationReply, "evaluation" ) + # Compute and report aggregated evaluation metrics. self.logger.info("Aggregating evaluation results.") loss, metrics = self._aggregate_evaluation_results(results) - self.logger.info("Global loss is: %s", loss) + self.logger.info("Averaged loss is: %s", loss) if metrics: - self.logger.info("Other global metrics are: %s", metrics) - self.checkpointer.checkpoint(loss) + self.logger.info( + "Other averaged scalar metrics are: %s", + {k: v for k, v in metrics.items() if isinstance(v, float)}, + ) + # Optionally checkpoint the model, optimizer and metrics. + if self.ckptr: + self._checkpoint_after_evaluation(metrics, results) + # Record the global loss, and update the kept "best" weights. + self._loss[round_i] = loss + if loss == min(self._loss.values()): + self._best = self.model.get_weights() def _select_evaluation_round_participants( self, @@ -551,6 +567,7 @@ class FederatedServer: # Case when the client reported some metrics. if reply.metrics: states = reply.metrics.copy() + # Update the global metrics based on the local ones. s_loss = states.pop("loss") loss += s_loss["current"] # type: ignore dvsr += s_loss["divisor"] # type: ignore @@ -567,6 +584,50 @@ class FederatedServer: loss = loss / dvsr return loss, metrics + def _checkpoint_after_evaluation( + self, + metrics: Dict[str, Union[float, np.ndarray]], + results: Dict[str, messaging.EvaluationReply], + ) -> None: + """Checkpoint the current model, optimizer and evaluation metrics. + + This method is meant to be called at the end of an evaluation round. + + Parameters + ---------- + metrics: dict[str, (float|np.ndarray)] + Aggregated evaluation metrics to checkpoint. + results: dict[str, EvaluationReply] + Client-wise EvaluationReply messages, based on which + `metrics` were already computed. + """ + # This method only works when a checkpointer is used. + if self.ckptr is None: + raise RuntimeError( + "`_checkpoint_after_evaluation` was called without " + "the FederatedServer having a Checkpointer." + ) + # Checkpoint the model, optimizer and global evaluation metrics. + timestamp = self.ckptr.checkpoint( + model=self.model, optimizer=self.optim, metrics=metrics + ) + # Checkpoint the client-wise metrics (or at least their loss). + # Use the same timestamp label as for global metrics and states. 
+        local = MetricSet.from_config(self.metrics.get_config())
+        for client, reply in results.items():
+            if reply.metrics:
+                local.reset()
+                local.agg_states(reply.metrics)
+                metrics = local.get_result()
+            else:
+                metrics = {"loss": reply.loss}
+            self.ckptr.save_metrics(
+                metrics=metrics,
+                prefix=f"metrics_{client}",
+                append=bool(self._loss),
+                timestamp=timestamp,
+            )
+
     def _keep_training(
         self,
         round_i: int,
@@ -589,7 +650,7 @@ class FederatedServer:
             self.logger.info("Maximum number of training rounds reached.")
             return False
         if early_stop is not None:
-            early_stop.update(self.checkpointer.get_loss(round_i))
+            early_stop.update(self._loss[round_i])
             if not early_stop.keep_training:
                 self.logger.info("Early stopping criterion reached.")
                 return False
@@ -606,11 +667,16 @@ class FederatedServer:
         rounds: int
            Number of training rounds taken until now.
         """
-        self.checkpointer.reset_best_weights()
+        self.logger.info("Recovering weights that yielded the lowest loss.")
         message = messaging.StopTraining(
-            weights=self.model.get_weights(),
-            loss=min(self.checkpointer.get_loss(i) for i in range(rounds)),
+            weights=self._best or self.model.get_weights(),
+            loss=min(self._loss.values()) if self._loss else float("nan"),
             rounds=rounds,
         )
         self.logger.info("Notifying clients that training is over.")
         await self.netwk.broadcast_message(message)
+        if self.ckptr:
+            path = f"{self.ckptr.folder}/model_state_best.json"
+            self.logger.info("Checkpointing final weights under %s.", path)
+            self.model.set_weights(message.weights)
+            self.ckptr.save_model(self.model, timestamp="best")
diff --git a/declearn/main/utils/_checkpoint.py b/declearn/main/utils/_checkpoint.py
index 058157a067182b62092c867d8113c910d5428f25..ecc95933c384dd780d26f386b2f1f8f035a355d1 100644
--- a/declearn/main/utils/_checkpoint.py
+++ b/declearn/main/utils/_checkpoint.py
@@ -2,14 +2,22 @@
 
 """Model and metrics checkpointing util."""
 
+import json
 import os
-from typing import List, Optional
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
+import pandas as pd  # type: ignore
 
-from declearn.model.api import Model, Vector
-from declearn.utils import json_dump, serialize_object
-
+from declearn.model.api import Model
+from declearn.optimizer import Optimizer
+from declearn.utils import (
+    deserialize_object,
+    json_dump,
+    json_load,
+    serialize_object,
+)
 
 __all__ = [
     "Checkpointer",
@@ -17,81 +25,539 @@ __all__ = [
 
 
 class Checkpointer:
-    """Model and metrics checkpointing class."""
+    """Model, optimizer, and metrics checkpointing class.
+
+    This class provides basic checkpointing capabilities that enable
+    saving a Model, an Optimizer and a dict of metric results
+    at various points throughout an experiment, and reloading these
+    checkpointed states and results.
+
+    The key method is `checkpoint`, which enables saving all three types
+    of objects at once and tagging them with a single timestamp label.
+    Note that its `first_call` bool parameter should be set to True on
+    the first call, to ensure the model's and optimizer's configurations
+    are saved in addition to their states, and to prevent the metrics
+    from being appended to files from a previous experiment.
+
+    Other methods are exposed that provide targeted saving and
+    loading: `save_model`, `save_optimizer`, `save_metrics` and their
+    counterparts `load_model`, `load_optimizer` and `load_metrics`.
+    Note that the latter may either be used to load metrics at a given
+    timestamp, or their entire history.
+    """
 
     def __init__(
         self,
-        model: Model,
-        folder: Optional[str] = None,
+        folder: str,
+        max_history: Optional[int] = None,
     ) -> None:
         """Instantiate the checkpointer.
 
         Parameters
         ----------
-        model: Model
-            Model, the config and weights from which to checkpoint.
-        folder: str or None, default=None
-            Optional folder where to write output files, such as
-            the loss values or the model's checkpointed weights.
-            If None, record losses in memory, as well as weights
-            having yielded the lowest loss only.
+        folder: str
+            Folder where to write output save files.
+        max_history: int or None, default=None
+            Maximum number of model and optimizer state save files to keep.
+            Older files are garbage-collected. If None, keep all files.
         """
-        self.model = model
         self.folder = folder
-        if self.folder is not None:
-            os.makedirs(self.folder, exist_ok=True)
-        self._best = None  # type: Optional[Vector]
-        self._loss = []  # type: List[float]
+        os.makedirs(self.folder, exist_ok=True)
+        if max_history is not None:
+            if not (isinstance(max_history, int) and max_history >= 0):
+                raise TypeError("'max_history' must be a positive int or None")
+        self.max_history = max_history
 
-    def save_model(
+    @classmethod
+    def from_specs(
+        cls,
+        inputs: Union[str, Dict[str, Any], "Checkpointer"],
+    ) -> "Checkpointer":
+        """Type-check and/or transform inputs into a Checkpointer instance.
+
+        This classmethod is merely implemented to keep duplicate and
+        boilerplate code from polluting FL orchestrating classes.
+
+        Parameters
+        ----------
+        inputs: Checkpointer or dict[str, any] or str
+            Checkpointer instance to type-check, or instantiation kwargs
+            to parse into one. If a single string is passed, treat it as
+            the `folder` argument, and use default other parameters.
+
+        Returns
+        -------
+        checkpointer: Checkpointer
+            Checkpointer instance, type-checked or instantiated from inputs.
+
+        Raises
+        ------
+        TypeError:
+            If `inputs` is of improper type.
+        Other exceptions may be raised when calling this class's `__init__`.
+        """
+        if isinstance(inputs, str):
+            inputs = {"folder": inputs}
+        if isinstance(inputs, dict):
+            inputs = cls(**inputs)
+        if not isinstance(inputs, Checkpointer):
+            raise TypeError("'inputs' should be a Checkpointer, dict or str.")
+        return inputs
+
+    # utility methods
+
+    def garbage_collect(
         self,
+        prefix: str,
     ) -> None:
-        """Save the wrapped model's configuration to a JSON file."""
-        if self.folder is not None:
-            path = os.path.join(self.folder, "model.json")
-            serialize_object(self.model).to_json(path)
+        """Delete files with matching prefix based on self.max_history.
+
+        Sort files starting with `prefix` under `self.folder`, and if
+        there are more than `self.max_history`, delete the first ones.
+        Files are expected to be named as "{prefix}_{timestamp}.{ext}"
+        so that deleting the first ones removes the oldest files.
+
+        Parameters
+        ----------
+        prefix: str
+            Prefix based on which to filter files under `self.folder`.
+        """
+        if self.folder and self.max_history:
+            files = self.sort_matching_files(prefix)
+            for idx in range(0, len(files) - self.max_history):
+                os.remove(os.path.join(self.folder, files[idx]))
+
+    def sort_matching_files(
+        self,
+        prefix: str,
+    ) -> List[str]:
+        """Return the sorted list of files under `self.folder` with a given prefix.
+
+        Parameters
+        ----------
+        prefix: str
+            Prefix based on which to filter files under `self.folder`.
+
+        Returns
+        -------
+        fnames: list[str]
+            Sorted list of names of files under `self.folder` that start
+            with `prefix`.
+ """ + fnames = [f for f in os.listdir(self.folder) if f.startswith(prefix)] + return sorted(fnames) + + # saving methods + + def save_model( + self, + model: Model, + config: bool = True, + state: bool = True, + timestamp: Optional[str] = None, + ) -> Optional[str]: + """Save a Model's configuration and/or weights to JSON files. + + Also garbage-collect existing files based on self.max_history. + + Parameters + ---------- + model: Model + Model instance to save. + config: bool, default=True + Flag indicating whether to save the model's config to a file. + state: bool, default=True + Flag indicating whether to save the model's weights to a file. + timestamp: str or None, default=None + Optional preset timestamp to add as weights file suffix. + + Returns + ------- + timestamp: str or None + Timestamp string labeling the output weights file, if any. + If `states is None`, return None. + """ + model_config = ( + None + if not config + else (serialize_object(model, allow_unregistered=True).to_dict()) + ) + return self._save_object( + prefix="model", + config=model_config, + states=model.get_weights() if state else None, + timestamp=timestamp, + ) + + def save_optimizer( + self, + optimizer: Optimizer, + config: bool = True, + state: bool = True, + timestamp: Optional[str] = None, + ) -> Optional[str]: + """Save an Optimizer's configuration and/or state to JSON files. + + Parameters + ---------- + optimizer: Optimizer + Optimizer instance to save. + config: bool, default=True + Flag indicating whether to save the optimizer's config to a file. + state: bool, default=True + Flag indicating whether to save the optimizer's state to a file. + timestamp: str or None, default=None + Optional preset timestamp to add as state file suffix. + + Returns + ------- + timestamp: str or None + Timestamp string labeling the output states file, if any. + If `states is None`, return None. + """ + return self._save_object( + prefix="optimizer", + config=optimizer.get_config() if config else None, + states=optimizer.get_state() if state else None, + timestamp=timestamp, + ) + + def _save_object( + self, + prefix: str, + config: Any = None, + states: Any = None, + timestamp: Optional[str] = None, + ) -> Optional[str]: + """Shared backend for `save_model` and `save_optimizer`. + + Parameters + ---------- + prefix: str + Prefix to the created file(s). + Also used to garbage-collect state files. + config: object or None, default=None + Optional JSON-serializable config to save. + Output file will be named "{prefix}.json". + states: object or None, default=None + Optional JSON-serializable data to save. + Output file will be named "{prefix}_{timestamp}.json". + timestamp: str or None, default=None + Optional preset timestamp to add as state file suffix. + If None, generate a timestamp to use. + + Returns + ------- + timestamp: str or None + Timestamp string labeling the output states file, if any. + If `states is None`, return None. 
+        """
+        if config:
+            fpath = os.path.join(self.folder, f"{prefix}_config.json")
+            json_dump(config, fpath)
+        if states is not None:
+            if timestamp is None:
+                timestamp = datetime.now().strftime("%y-%m-%d_%H-%M-%S")
+            fpath = os.path.join(
+                self.folder, f"{prefix}_state_{timestamp}.json"
+            )
+            json_dump(states, fpath)
+            self.garbage_collect(f"{prefix}_state")
+            return timestamp
+        return None
+
+    def save_metrics(
+        self,
+        metrics: Dict[str, Union[float, np.ndarray]],
+        prefix: str = "metrics",
+        append: bool = True,
+        timestamp: Optional[str] = None,
+    ) -> str:
+        """Save a dict of metrics to a csv and a json file.
+
+        Parameters
+        ----------
+        metrics: dict[str, (float | np.ndarray)]
+            Dict storing metric values that need saving.
+            Note that numpy arrays will be converted to lists.
+        prefix: str, default="metrics"
+            Prefix to the output files' names.
+        append: bool, default=True
+            Whether to append to the files in case they already exist.
+            If False, overwrite any existing file.
+        timestamp: str or None, default=None
+            Optional preset timestamp to associate with the metrics.
+
+        Returns
+        -------
+        timestamp: str
+            Timestamp string labelling the checkpointed metrics.
+        """
+        # Set up a timestamp and convert metrics to raw-JSON-compatible values.
+        if timestamp is None:
+            timestamp = datetime.now().strftime("%y-%m-%d_%H-%M-%S")
+        scores = {
+            key: val.tolist() if isinstance(val, np.ndarray) else float(val)
+            for key, val in metrics.items()
+        }
+        # Filter out scalar metrics and write them to a csv file.
+        scalars = {k: v for k, v in scores.items() if isinstance(v, float)}
+        fpath = os.path.join(self.folder, f"{prefix}.csv")
+        pd.DataFrame(scalars, index=[timestamp]).to_csv(
+            fpath,
+            sep=",",
+            mode=("a" if append else "w"),
+            header=not (append and os.path.isfile(fpath)),
+            index=True,
+            index_label="timestamp",
+            encoding="utf-8",
+        )
+        # Write the full set of metrics to a JSON file.
+        jdump = json.dumps({timestamp: scores})[1:-1]  # bracket-less dict
+        fpath = os.path.join(self.folder, f"{prefix}.json")
+        mode = "a" if append and os.path.isfile(fpath) else "w"
+        with open(fpath, mode=mode, encoding="utf-8") as file:
+            # First time, initialize the json file as a dict.
+            if mode == "w":
+                file.write(f"{{\n{jdump}\n}}")
+            # Otherwise, append the record into the existing json dict.
+            else:
+                file.truncate(file.tell() - 2)  # remove trailing "\n}"
+                file.write(f",\n{jdump}\n}}")  # append, then restore "\n}"
+        # Return the timestamp label.
+        return timestamp
 
     def checkpoint(
         self,
-        loss: float,
-    ) -> None:
-        """Checkpoint the loss value and the model's weights.
+        model: Optional[Model] = None,
+        optimizer: Optional[Optimizer] = None,
+        metrics: Optional[Dict[str, Union[float, np.ndarray]]] = None,
+        first_call: bool = False,
+    ) -> str:
+        """Checkpoint inputs, using a common timestamp label.
+
+        Parameters
+        ----------
+        model: Model or None, default=None
+            Optional Model to checkpoint.
+            This will call `self.save_model(config=False, state=True)`.
+        optimizer: Optimizer or None, default=None
+            Optional Optimizer to checkpoint.
+            This will call `self.save_optimizer(config=False, state=True)`.
+        metrics: dict[str, (float | np.ndarray)] or None, default=None
+            Optional dict of metrics to checkpoint.
+            This will call `self.save_metrics(append=True)`.
+        first_call: bool, default=False
+            Flag indicating whether to treat this checkpoint as the first
+            one. If True, export the model and optimizer configurations
+            and/or erase pre-existing configuration and metrics files.
+ + Returns + ------- + timestamp: str + Timestamp string labeling the model weights and optimizer state + files, as well as the values appended to the metrics files. + """ + timestamp = datetime.now().strftime("%y-%m-%d_%H-%M-%S") + remove = [] # type: List[str] + if model: + self.save_model( + model, config=first_call, state=True, timestamp=timestamp + ) + elif first_call: + remove.append(os.path.join(self.folder, "model_config.json")) + if optimizer: + self.save_optimizer( + optimizer, config=first_call, state=True, timestamp=timestamp + ) + elif first_call: + remove.append(os.path.join(self.folder, "optimizer_config.json")) + if metrics: + append = not first_call + self.save_metrics( + metrics, prefix="metrics", append=append, timestamp=timestamp + ) + elif first_call: + remove.append(os.path.join(self.folder, "metrics.csv")) + remove.append(os.path.join(self.folder, "metrics.json")) + for path in remove: + if os.path.isfile(path): + os.remove(path) + return timestamp + + # Loading methods - If `self.folder is not None`, append the loss value to - the "losses.txt" file and record model weights under a - "weights_{i}.json" file. - Otherwise, retain the loss, and the model's weights if - the loss is at its lowest. + def load_model( + self, + model: Optional[Model] = None, + timestamp: Optional[str] = None, + load_state: bool = True, + ) -> Model: + """Instantiate a Model and/or reset its weights from a save file. Parameters ---------- - loss: float - Loss value associated with the current model state. + model: Model or None, default=None + Optional Model, the weights of which to reload. + If None, instantiate from the model config file (or raise). + timestamp: str or None, default=None + Optional timestamp string labeling the weights to reload. + If None, use the weights with the most recent timestamp. + load_state: bool, default=True + Flag specifying whether model weights are to be reloaded. + If `False`, `timestamp` will be ignored. """ - self._loss.append(loss) - if loss <= np.min(self._loss): - self._best = self.model.get_weights() - if self.folder is not None: - # Save the model's weights to a JSON file. - indx = len(self._loss) - path = os.path.join(self.folder, f"weights_{indx}.json") - json_dump(self.model.get_weights(), path) - # Append the loss to a txt file. - path = os.path.join(self.folder, "losses.txt") - mode = "a" if indx > 1 else "w" - with open(path, mode, encoding="utf-8") as file: - file.write(f"{indx}: {loss}\n") - - def reset_best_weights( + # Type-check or reload the Model from a config file. + if model is None: + fpath = os.path.join(self.folder, "model_config.json") + if not os.path.isfile(fpath): + raise FileNotFoundError( + "Cannot reload Model: config file not found." + ) + model = deserialize_object(fpath) # type: ignore + if not isinstance(model, Model): + raise TypeError( + f"The object reloaded from {fpath} is not a Model." + ) + if not isinstance(model, Model): + raise TypeError("'model' should be a Model or None.") + # Load the model weights and assign them. + if load_state: + weights = self._load_state("model", timestamp=timestamp) + model.set_weights(weights) + return model + + def load_optimizer( self, - ) -> None: - """Restore the model's weights associated with the lowest past loss.""" - if self._best is not None: - self.model.set_weights(self._best) + optimizer: Optional[Optimizer] = None, + timestamp: Optional[str] = None, + load_state: bool = True, + ) -> Optimizer: + """Instantiate an Optimizer and/or reset its state from a save file. 
- def get_loss( + Parameters + ---------- + optimizer: Optimizer or None, default=None + Optional Optimizer, the weights of which to reload. + If None, instantiate from the optimizer config file (or raise). + timestamp: str or None, default=None + Optional timestamp string labeling the state to reload. + If None, use the state with the most recent timestamp. + load_state: bool, default=True + Flag specifying whether optimizer state are to be reloaded. + If `False`, `timestamp` will be ignored. + """ + # Type-check or reload the Optimizer from a config file. + if optimizer is None: + fpath = os.path.join(self.folder, "optimizer_config.json") + if not os.path.isfile(fpath): + raise FileNotFoundError( + "Cannot reload Optimizer: config file not found." + ) + config = json_load(fpath) + optimizer = Optimizer.from_config(config) + if not isinstance(optimizer, Optimizer): + raise TypeError("'optimizer' should be an Optimizer or None.") + # Load the optimizer state and assign it. + if load_state: + state = self._load_state("optimizer", timestamp=timestamp) + optimizer.set_state(state) + return optimizer + + def _load_state( + self, + prefix: str, + timestamp: Optional[str] = None, + ) -> Any: + """Reload data from a state checkpoint file. + + Parameters + ---------- + prefix: str + Prefix to the target state file. + timestamp: str or None, default=None + Optional timestamp string labeling the state to reload. + If None, use the state with the most recent timestamp. + """ + if isinstance(timestamp, str): + fname = f"{prefix}_state_{timestamp}.json" + else: + files = self.sort_matching_files(f"{prefix}_state") + if not files: + raise FileNotFoundError( + f"Cannot reload {prefix} state: no state file found." + ) + fname = files[-1] + return json_load(os.path.join(self.folder, fname)) + + def load_metrics( self, - index: int, - ) -> float: - """Return the loss value recorded at a given index.""" - return self._loss[index] + prefix: str = "metrics", + timestamp: Optional[str] = None, + ) -> Dict[str, Dict[str, Union[float, np.ndarray]]]: + """Reload checkpointed metrics. + + To only reload scalar metrics as a timestamp-indexed dataframe, + see the `load_scalar_metrics` method. + + Parameters + ---------- + prefix: str, default="metrics" + Prefix to the metrics save file's name. + timestamp: str or None, default=None + Optional timestamp string labeling the metrics to reload. + If None, return all checkpointed metrics. + + Returns + ------- + metrics: dict[str, dict[str, (float | np.ndarray)]] + Dict of metrics, with `{timestamp: {key: value}}` format. + If the `timestamp` argument was not None, the first dimension + will only contain one key, which is that timestamp. + """ + fpath = os.path.join(self.folder, f"{prefix}.json") + if not os.path.isfile(fpath): + raise FileNotFoundError( + f"Cannot reload metrics: file {fpath} does not exit." + ) + with open(fpath, "r", encoding="utf-8") as file: + metrics = json.load(file) + if timestamp: + if timestamp not in metrics: + raise KeyError( + f"The reloaded metrics have no {timestamp}-labeled entry." + ) + metrics = {timestamp: metrics[timestamp]} + return { + timestamp: { + key: np.array(val) if isinstance(val, list) else val + for key, val in scores.items() + } + for timestamp, scores in metrics.items() + } + + def load_scalar_metrics( + self, + prefix: str = "metrics", + ) -> pd.DataFrame: + """Return a pandas DataFrame storing checkpointed scalar metrics. + + To reload all checkpointed metrics (i.e. 
scalar and numpy array ones) + see the `load_metrics` method. + + Parameters + ---------- + prefix: str, default="metrics" + Prefix to the metrics save file's name. + + Returns + ------- + metrics: pandas.DataFrame + DataFrame storing timestamp-indexed scalar metrics. + """ + fpath = os.path.join(self.folder, f"{prefix}.csv") + if not os.path.isfile(fpath): + raise FileNotFoundError( + f"Cannot reload scalar metrics: file {fpath} does not exit." + ) + return pd.read_csv(fpath, index_col="timestamp") diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index f38754180744778212fb576a05bdf4906a36038a..ea4b2106afaef3d5ce6ce42ec19bb54733b35ebe 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -327,7 +327,10 @@ class TrainingManager: effort = constraints.get_values() result = self.metrics.get_result() states = self.metrics.get_states() - self.logger.info("Local evaluation metrics: %s", result) + self.logger.info( + "Local scalar evaluation metrics: %s", + {k: v for k, v in result.items() if isinstance(v, float)}, + ) # Pack the result and computational effort information into a message. self.logger.info("Packing local results to be sent to the server.") return messaging.EvaluationReply( diff --git a/examples/heart-uci/client.py b/examples/heart-uci/client.py index 7dec5c26d088ec6575973f48bc7d5b12496bb3e1..7a85018e37e1ce0aab1a75379bdceae9a6b7b913 100644 --- a/examples/heart-uci/client.py +++ b/examples/heart-uci/client.py @@ -68,7 +68,8 @@ def run_client( # (5) Instantiate a FederatedClient and run it. client = FederatedClient( - network, train, valid, folder=f"{FILEDIR}/results/{name}" + # fmt: off + network, train, valid, checkpoint=f"{FILEDIR}/results/{name}" # Note: you may add `share_metrics=False` to prevent sending # evaluation metrics to the server, out of privacy concerns ) diff --git a/examples/heart-uci/data.py b/examples/heart-uci/data.py index 74a137c688bff217caddb6dfc31e909079fd9bc8..cb11b2608e9ae04bf6f4bd819961b801a2be738e 100644 --- a/examples/heart-uci/data.py +++ b/examples/heart-uci/data.py @@ -2,7 +2,7 @@ import argparse import os -from typing import List +from typing import Collection import pandas as pd @@ -29,14 +29,14 @@ DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") def get_data( - dir: str = DATADIR, - names: List[str] = NAMES, + folder: str = DATADIR, + names: Collection[str] = NAMES, ) -> None: """Download and process the UCI heart disease dataset. Arguments --------- - dir: str + folder: str Path to the folder where to write output csv files. names: list[str] Names of centers, the dataset from which to download, @@ -61,8 +61,8 @@ def get_data( # Binarize the target variable. df["num"] = (df["num"] > 0).astype(int) # Export the resulting dataset to a csv file. - os.makedirs(dir, exist_ok=True) - df.to_csv(f"{dir}/{name}.csv", index=False) + os.makedirs(folder, exist_ok=True) + df.to_csv(f"{folder}/{name}.csv", index=False) # Code executed when the script is called directly. @@ -70,7 +70,7 @@ if __name__ == "__main__": # Parse commandline parameters. parser = argparse.ArgumentParser() parser.add_argument( - "--dir", + "--folder", type=str, default=DATADIR, help="folder where to write output csv files", @@ -84,4 +84,4 @@ if __name__ == "__main__": ) args = parser.parse_args() # Download and pre-process the selected dataset(s). 
- get_data(dir=args.dir, names=args.names) + get_data(folder=args.folder, names=args.names) diff --git a/examples/heart-uci/server.py b/examples/heart-uci/server.py index 2ef418957ae3d3b9133fe773dfe4bfacaf934a31..93eea30962a7b70f30098d3e5eb71882270f20db 100644 --- a/examples/heart-uci/server.py +++ b/examples/heart-uci/server.py @@ -5,10 +5,9 @@ import os from declearn.communication import NetworkServerConfig from declearn.main import FederatedServer -from declearn.main.config import FLRunConfig, FLOptimConfig +from declearn.main.config import FLOptimConfig, FLRunConfig from declearn.model.sklearn import SklearnSGDModel - FILEDIR = os.path.dirname(os.path.abspath(__file__)) @@ -85,7 +84,10 @@ def run_server( # f1-score and roc auc (with plot-enabling fpr/tpr curves) during # evaluation rounds. server = FederatedServer( - model, network, optim, metrics=["binary-classif", "binary-roc"] + # fmt: off + model, network, optim, + metrics=["binary-classif", "binary-roc"], + checkpoint=f"{FILEDIR}/results/server" ) # Here, we set up 20 rounds of training, with 30 samples per batch diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3e4d8831b88d31d12ff656b46a395782852e00 --- /dev/null +++ b/test/main/test_checkpoint.py @@ -0,0 +1,415 @@ +# coding: utf-8 + +"""Unit tests for Checkpointer class.""" + +import json +import os +from pathlib import Path +from typing import Dict, Iterator, List, Union +from unittest import mock + +import numpy as np +import pandas as pd +import pytest +from sklearn.linear_model import SGDClassifier + +from declearn.main.utils import Checkpointer +from declearn.model.api import Model +from declearn.model.sklearn import SklearnSGDModel +from declearn.optimizer import Optimizer +from declearn.utils import json_load + + +# Fixtures and utils + + +@pytest.fixture(name="checkpointer") +def fixture_checkpointer(tmp_path) -> Iterator[Checkpointer]: + """Create a checkpointer within a temp dir""" + yield Checkpointer(tmp_path, 2) + + +@pytest.fixture(name="model") +def fixture_model() -> SklearnSGDModel: + """Crete a toy binary-classification model.""" + model = SklearnSGDModel(SGDClassifier()) + model.initialize({"n_features": 8, "classes": np.arange(2)}) + return model + + +@pytest.fixture(name="optimizer") +def fixture_optimizer() -> Optimizer: + """Create a toy optimizer""" + testopt = Optimizer(lrate=1.0, modules=[("momentum", {"beta": 0.95})]) + return testopt + + +@pytest.fixture(name="metrics") +def fixture_metrics() -> Dict[str, float]: + """Create a metrics fixture""" + return {"loss": 0.5} + + +def create_state_files(folder: str, type_obj: str, n_files: int) -> List[str]: + """Create test state files in checkpointer.ckpt""" + files = [ + f"{type_obj}_state_23-01-{21 + idx}_15-45-35.json" + for idx in range(n_files) + ] + for name in files: + with open(os.path.join(folder, name), "w", encoding="utf-8") as file: + json.dump({"test": "state"}, file) + return files + + +def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str: + """Create test cfg files in checkpointer.ckpt""" + path = os.path.join(checkpointer.folder, f"{type_obj}_config.json") + with open(path, "w", encoding="utf-8") as file: + json.dump({"test": "config"}, file) + return f"{type_obj}_config.json" + + +# Actual tests + + +class TestCheckpointer: + + """Unit tests for Checkpointer class""" + + def test_init_default(self, tmp_path: str) -> None: + """Test `Checkpointer.__init__` with 
`max_history=None`.""" + checkpointer = Checkpointer(folder=tmp_path, max_history=None) + assert checkpointer.folder == tmp_path + assert Path(checkpointer.folder).is_dir() + assert checkpointer.max_history is None + + def test_init_max_history(self, tmp_path: str) -> None: + """Test `Checkpointer.__init__` with `max_history=2`.""" + checkpointer = Checkpointer(folder=tmp_path, max_history=2) + assert checkpointer.folder == tmp_path + assert Path(checkpointer.folder).is_dir() + assert checkpointer.max_history == 2 + + def test_init_fails(self, tmp_path: str) -> None: + """Test `Checkpointer.__init__` raises on negative `max_history`.""" + with pytest.raises(TypeError): + Checkpointer(folder=tmp_path, max_history=-1) + + def test_from_specs(self, tmp_path: str) -> None: + """Test that `Checkpointer.from_specs` works properly. + + This test is multi-part rather than unitary as the method + is merely boilerplate code refactored into a classmethod. + """ + tmp_path = str(tmp_path) # note: PosixPath + specs_list = [ + tmp_path, + {"folder": tmp_path, "max_history": None}, + Checkpointer(tmp_path), + ] + # Iteratively test the various types of acceptable specs. + for specs in specs_list: + ckpt = Checkpointer.from_specs(specs) # type: ignore + assert isinstance(ckpt, Checkpointer) + assert ckpt.folder == tmp_path + assert ckpt.max_history is None + # Also test that the documented TypeError is raised. + with pytest.raises(TypeError): + Checkpointer.from_specs(0) # type: ignore + + def test_garbage_collect(self, tmp_path: str) -> None: + """Test `Checkpointer.garbage_collect` when collection is needed.""" + # Set up a checkpointer with max_history=2 and 3 state files. + checkpointer = Checkpointer(folder=tmp_path, max_history=2) + names = sorted(create_state_files(tmp_path, "model", n_files=3)) + checkpointer.garbage_collect("model_state") + # Verify that the "oldest" file was removed. + files = sorted(os.listdir(checkpointer.folder)) + assert len(files) == checkpointer.max_history + assert files == names[1:] # i.e. [-max_history:] + + def test_garbage_collect_no_collection(self, tmp_path: str) -> None: + """Test `Checkpointer.garbage_collect` when collection is unneeded.""" + # Set up a checkpointer with max_history=3 and 2 state files. + checkpointer = Checkpointer(folder=tmp_path, max_history=3) + names = sorted(create_state_files(tmp_path, "model", n_files=2)) + checkpointer.garbage_collect("model_state") + # Verify that no files were removed. + files = sorted(os.listdir(checkpointer.folder)) + assert files == names + + def test_garbage_collect_infinite_history(self, tmp_path: str) -> None: + """Test `Checkpointer.garbage_collect` when `max_history=None`.""" + # Set up a checkpointer with max_history=None and 3 state files. + checkpointer = Checkpointer(folder=tmp_path, max_history=None) + names = sorted(create_state_files(tmp_path, "model", n_files=3)) + checkpointer.garbage_collect("model_state") + # Verify that no files were removed. 
+        files = sorted(os.listdir(checkpointer.folder))
+        assert files == names
+
+    def test_sort_matching_files(self, tmp_path: str) -> None:
+        """Test `Checkpointer.sort_matching_files`."""
+        checkpointer = Checkpointer(folder=tmp_path)
+        names = sorted(create_state_files(tmp_path, "model", n_files=3))
+        create_state_files(tmp_path, "optimizer", n_files=2)
+        files = checkpointer.sort_matching_files("model_state")
+        assert names == files
+
+    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
+    @pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
+    def test_save_model(
+        self, tmp_path: str, model: Model, config: bool, state: bool
+    ) -> None:
+        """Test `Checkpointer.save_model` with provided parameters."""
+        checkpointer = Checkpointer(folder=tmp_path)
+        timestamp = checkpointer.save_model(model, config, state)
+        # Verify config save file's existence.
+        cfg_path = os.path.join(checkpointer.folder, "model_config.json")
+        if config:
+            assert Path(cfg_path).is_file()
+        else:
+            assert not Path(cfg_path).is_file()
+        # Verify weights save file's existence.
+        if state:  # test state file save
+            assert isinstance(timestamp, str)
+            state_path = os.path.join(
+                checkpointer.folder, f"model_state_{timestamp}.json"
+            )
+            assert Path(state_path).is_file()
+        else:
+            assert timestamp is None
+            assert not checkpointer.sort_matching_files("model_state")
+
+    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
+    @pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
+    def test_save_optimizer(
+        self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
+    ) -> None:
+        """Test `Checkpointer.save_optimizer` with provided parameters."""
+        checkpointer = Checkpointer(folder=tmp_path)
+        timestamp = checkpointer.save_optimizer(optimizer, config, state)
+        # Verify config save file's existence.
+        cfg_path = os.path.join(checkpointer.folder, "optimizer_config.json")
+        if config:
+            assert Path(cfg_path).is_file()
+        else:
+            assert not Path(cfg_path).is_file()
+        # Verify state save file's existence.
+        if state:
+            assert isinstance(timestamp, str)
+            state_path = os.path.join(
+                checkpointer.folder, f"optimizer_state_{timestamp}.json"
+            )
+            assert Path(state_path).is_file()
+        else:
+            assert timestamp is None
+            assert not checkpointer.sort_matching_files("optimizer_state")
+
+    def test_save_metrics(self, tmp_path: str) -> None:
+        """Test that `Checkpointer.save_metrics` works as expected.
+
+        This is a multi-part test rather than a unit one, to verify
+        that the `append` parameter and its backend work properly.
+        """
+        # Setup for this multi-part test.
+        metrics = {
+            "foo": 42.0,
+            "bar": np.array([0, 1]),
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        checkpointer = Checkpointer(tmp_path)
+        csv_path = os.path.join(tmp_path, "metrics.csv")
+        json_path = os.path.join(tmp_path, "metrics.json")
+
+        # Case 'append=True' but the files do not exist.
+        checkpointer.save_metrics(metrics, append=True, timestamp="0")
+        assert os.path.isfile(csv_path)
+        assert os.path.isfile(json_path)
+        scalars = pd.DataFrame({"timestamp": [0], "foo": [42.0]})
+        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
+        m_json = {"foo": 42.0, "bar": [0, 1]}
+        assert json_load(json_path) == {"0": m_json}
+
+        # Case 'append=False', overwriting existing files.
+        checkpointer.save_metrics(metrics, append=False, timestamp="0")
+        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
+        assert json_load(json_path) == {"0": m_json}
+
+        # Case 'append=True', appending to existing files.
+        checkpointer.save_metrics(metrics, append=True, timestamp="1")
+        scalars = pd.DataFrame({"timestamp": [0, 1], "foo": [42.0, 42.0]})
+        m_json = {"0": m_json, "1": m_json}
+        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
+        assert json_load(json_path) == m_json
+
+    @pytest.mark.parametrize("first", [True, False], ids=["first", "notfirst"])
+    def test_checkpoint(
+        self, tmp_path: str, model: Model, optimizer: Optimizer, first: bool
+    ) -> None:
+        """Test that `Checkpointer.checkpoint` works as expected."""
+        # Set up a checkpointer and call its checkpoint method.
+        checkpointer = Checkpointer(tmp_path)
+        metrics = {"foo": 42.0, "bar": np.array([0, 1])}
+        if first:  # create some files that should be removed on `first_call`
+            create_config_file(checkpointer, "model")
+        timestamp = checkpointer.checkpoint(
+            model=model,
+            optimizer=optimizer,
+            metrics=metrics,  # type: ignore
+            first_call=first,
+        )
+        assert isinstance(timestamp, str)
+        # Verify whether config and metric files exist, as expected.
+        m_cfg = os.path.join(tmp_path, "model_config.json")
+        o_cfg = os.path.join(tmp_path, "optimizer_config.json")
+        if first:
+            assert os.path.isfile(m_cfg)
+            assert os.path.isfile(o_cfg)
+        else:
+            assert not os.path.isfile(m_cfg)
+            assert not os.path.isfile(o_cfg)
+        # Verify that state and metric files exist as expected.
+        path = os.path.join(tmp_path, f"model_state_{timestamp}.json")
+        assert os.path.isfile(path)
+        path = os.path.join(tmp_path, f"optimizer_state_{timestamp}.json")
+        assert os.path.isfile(path)
+        assert os.path.isfile(os.path.join(tmp_path, "metrics.csv"))
+        assert os.path.isfile(os.path.join(tmp_path, "metrics.json"))
+
+    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
+    @pytest.mark.parametrize("config", [True, False], ids=["config", "model"])
+    def test_load_model(
+        self, tmp_path: str, model: Model, config: bool, state: bool
+    ) -> None:
+        """Test `Checkpointer.load_model` with provided parameters."""
+        checkpointer = Checkpointer(tmp_path)
+        # Save the model (config + weights), then reload based on parameters.
+        timestamp = checkpointer.save_model(model, config=True, state=True)
+        with mock.patch.object(type(model), "set_weights") as p_set_weights:
+            loaded_model = checkpointer.load_model(
+                model=(None if config else model),
+                timestamp=(timestamp if config else None),  # arbitrary swap
+                load_state=state,
+            )
+        # Verify that the loaded model is either the input one or similar.
+        if config:
+            assert isinstance(loaded_model, type(model))
+            assert loaded_model is not model
+            assert loaded_model.get_config() == model.get_config()
+        else:
+            assert loaded_model is model
+        # Verify that `set_weights` was called, with proper values.
+        if state:
+            p_set_weights.assert_called_once()
+            if config:
+                assert loaded_model.get_weights() == model.get_weights()
+        else:
+            p_set_weights.assert_not_called()
+
+    def test_load_model_fails(self, tmp_path: str, model: Model) -> None:
+        """Test that `Checkpointer.load_model` raises expected errors."""
+        checkpointer = Checkpointer(tmp_path)
+        # Case when the weights file is missing.
+        checkpointer.save_model(model, config=False, state=False)
+        with pytest.raises(FileNotFoundError):
+            checkpointer.load_model(model=model, load_state=True)
+        # Case when the config file is missing.
+        checkpointer.save_model(model, config=False, state=True)
+        with pytest.raises(FileNotFoundError):
+            checkpointer.load_model(model=None)
+        # Case when a wrong model input is provided.
+        with pytest.raises(TypeError):
+            checkpointer.load_model(model="wrong-type")  # type: ignore
+
+    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
+    @pytest.mark.parametrize("config", [True, False], ids=["config", "optim"])
+    def test_load_optimizer(
+        self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
+    ) -> None:
+        """Test `Checkpointer.load_optimizer` with provided parameters."""
+        checkpointer = Checkpointer(tmp_path)
+        # Save the optimizer (config + state), then reload based on parameters.
+        stamp = checkpointer.save_optimizer(optimizer, config=True, state=True)
+        with mock.patch.object(Optimizer, "set_state") as p_set_state:
+            loaded_optim = checkpointer.load_optimizer(
+                optimizer=(None if config else optimizer),
+                timestamp=(stamp if config else None),  # arbitrary swap
+                load_state=state,
+            )
+        # Verify that the loaded optimizer is either the input one or similar.
+        if config:
+            assert isinstance(loaded_optim, Optimizer)
+            assert loaded_optim is not optimizer
+            assert loaded_optim.get_config() == optimizer.get_config()
+        else:
+            assert loaded_optim is optimizer
+        # Verify that `set_state` was called, with proper values.
+        if state:
+            p_set_state.assert_called_once()
+            if config:
+                assert loaded_optim.get_state() == optimizer.get_state()
+        else:
+            p_set_state.assert_not_called()
+
+    def test_load_optimizer_fails(
+        self, tmp_path: str, optimizer: Optimizer
+    ) -> None:
+        """Test that `Checkpointer.load_optimizer` raises expected errors."""
+        checkpointer = Checkpointer(tmp_path)
+        # Case when the state file is missing.
+        checkpointer.save_optimizer(optimizer, config=False, state=False)
+        with pytest.raises(FileNotFoundError):
+            checkpointer.load_optimizer(optimizer=optimizer, load_state=True)
+        # Case when the config file is missing.
+        checkpointer.save_optimizer(optimizer, config=False, state=True)
+        with pytest.raises(FileNotFoundError):
+            checkpointer.load_optimizer(optimizer=None)
+        # Case when a wrong optimizer input is provided.
+        with pytest.raises(TypeError):
+            checkpointer.load_optimizer(optimizer="wrong-type")  # type: ignore
+
+    def test_load_metrics(self, tmp_path: str) -> None:
+        """Test that `Checkpointer.load_metrics` works properly."""
+        # Set up things by saving a couple of sets of metrics.
+        metrics = {
+            "foo": 42.0,
+            "bar": np.array([0, 1]),
+        }  # type: Dict[str, Union[float, np.ndarray]]
+        checkpointer = Checkpointer(tmp_path)
+        time_0 = checkpointer.save_metrics(metrics, append=False)
+        time_1 = checkpointer.save_metrics(metrics, append=True)
+        # Test reloading all checkpointed metrics.
+        reloaded = checkpointer.load_metrics(timestamp=None)
+        assert isinstance(reloaded, dict)
+        assert reloaded.keys() == {time_0, time_1}
+        for scores in reloaded.values():
+            assert isinstance(scores, dict)
+            assert scores.keys() == metrics.keys()
+            assert scores["foo"] == metrics["foo"]
+            assert (scores["bar"] == metrics["bar"]).all()  # type: ignore
+        # Test reloading only metrics from one timestamp.
+        reloaded = checkpointer.load_metrics(timestamp=time_0)
+        assert isinstance(reloaded, dict)
+        assert reloaded.keys() == {time_0}
+
+    def test_load_scalar_metrics(self, tmp_path: str) -> None:
+        """Test that `Checkpointer.load_scalar_metrics` works properly."""
+        # Set up things by saving a couple of sets of metrics.
+ metrics = { + "foo": 42.0, + "bar": np.array([0, 1]), + } # type: Dict[str, Union[float, np.ndarray]] + checkpointer = Checkpointer(tmp_path) + time_0 = checkpointer.save_metrics(metrics, append=False) + time_1 = checkpointer.save_metrics(metrics, append=True) + expect = pd.DataFrame( + {"foo": [42.0, 42.0], "timestamp": [time_0, time_1]} + ).set_index("timestamp") + # Test reloading scalar metrics. + scores = checkpointer.load_scalar_metrics() + assert isinstance(scores, pd.DataFrame) + assert scores.index.names == expect.index.names + assert scores.columns == expect.columns + assert scores.shape == expect.shape + assert (scores == expect).all(axis=None) diff --git a/test/test_main.py b/test/test_main.py index ebda7f3463d244ed024095573a2903344b106dc5..be8ac62afd62bd3bfbc8b0f07c9eaf1bf97ffa05 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -193,7 +193,7 @@ class DeclearnTestCase: netwk = self.build_netwk_server() optim = self.build_optim_config() with tempfile.TemporaryDirectory() as folder: - server = FederatedServer(model, netwk, optim, folder=folder) + server = FederatedServer(model, netwk, optim, checkpoint=folder) config = { "rounds": self.rounds, "register": {"max_clients": self.nb_clients, "timeout": 20},
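
For quick orientation, the sketch below strings together the `Checkpointer` API introduced by this patch (`checkpoint`, `save_*`, `load_*`), reusing the toy model and optimizer fixtures from `test/main/test_checkpoint.py`; the `outputs` folder path is purely illustrative.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

from declearn.main.utils import Checkpointer
from declearn.model.sklearn import SklearnSGDModel
from declearn.optimizer import Optimizer

# Toy model and optimizer, mirroring the unit-test fixtures.
model = SklearnSGDModel(SGDClassifier())
model.initialize({"n_features": 8, "classes": np.arange(2)})
optim = Optimizer(lrate=1.0, modules=[("momentum", {"beta": 0.95})])

# Keep at most the two most recent model/optimizer state files on disk.
ckptr = Checkpointer(folder="outputs", max_history=2)

# First call: also writes model_config.json and optimizer_config.json,
# and clears metrics files left over from a previous run.
ckptr.checkpoint(model=model, optimizer=optim, first_call=True)

# Later calls: save states and append metrics under a shared timestamp label.
stamp = ckptr.checkpoint(model=model, optimizer=optim, metrics={"loss": 0.5})

# Reloading: rebuild the model from its config file, reset the optimizer's
# state in place, and fetch the metrics recorded under the `stamp` label.
model = ckptr.load_model(model=None)
optim = ckptr.load_optimizer(optimizer=optim)
scores = ckptr.load_metrics(timestamp=stamp)
```

The `checkpoint` argument added to `FederatedServer` and `FederatedClient` accepts the same specs as `Checkpointer.from_specs`: a `Checkpointer` instance, a kwargs dict, or a plain folder path string.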