Commit ddc04028 authored by ANDREY Paul

Merge branch 'revise-checkpointer' into 'main'

Revise `Checkpointer`.

See merge request !21
parents 5f23f935 4cf8b549
Pipeline #749140 waiting for manual action
@@ -182,7 +182,9 @@ optim = declearn.main.FLOptimConfig.from_params(
     aggregator="averaging",
     client_opt=0.001,
 )
-server = declearn.main.FederatedServer(model, netwk, optim, folder="outputs")
+server = declearn.main.FederatedServer(
+    model, netwk, optim, checkpoint="outputs"
+)
 config = declearn.main.config.FLRunConfig.from_params(
     rounds=10,
     register={"min_clients": 1, "max_clients": 3, "timeout": 180},
@@ -206,7 +208,7 @@ train = declearn.dataset.InMemoryDataset(
     expose_classes=True  # enable sharing of unique target values
 )
 valid = declearn.dataset.InMemoryDataset("path/to/valid.csv", target="label")
-client = declearn.main.FederatedClient(netwk, train, valid, folder="outputs")
+client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs")
 client.run()
 ```
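
For reference, the revised `checkpoint` argument accepts several forms, as documented in the updated docstrings further down in this diff. A minimal sketch (not part of the changeset; the `outputs` folder and `max_history` value are arbitrary):

```python
from declearn.main.utils import Checkpointer

# Any of the following may be passed as `checkpoint=...` to FederatedServer
# or FederatedClient; it is resolved through `Checkpointer.from_specs`.
ckpt_from_path = "outputs"                                # folder path, default params
ckpt_from_dict = {"folder": "outputs", "max_history": 2}  # instantiation kwargs
ckpt_instance = Checkpointer(folder="outputs", max_history=2)  # ready-made instance
```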
@@ -249,9 +251,10 @@ exposed here.
     - decide whether to continue, based on the number of
       rounds taken or on the evolution of the global loss
 - Finally:
+  - restore the model weights that yielded the lowest global loss
   - notify clients that training is over, so they can disconnect
-    and run their final routine (e.g. model saving)
-  - optionally save the model (through a checkpointer)
+    and run their final routine (e.g. save the "best" model)
+  - optionally checkpoint the "best" model
   - close the network server and end the process

 #### Detail of the process phases
@@ -319,14 +322,15 @@ exposed here.
   - update model weights
   - perform evaluation steps based on effort constraints
     - step: update evaluation metrics, including the model's loss, over a batch
-  - checkpoint the model, then send results to the server
-    - optionally prevent sharing detailed metrics with the server; always
-      include the scalar validation loss value
+  - optionally checkpoint the model, local optimizer and evaluation metrics
+  - send results to the server: optionally prevent sharing detailed metrics;
+    always include the scalar validation loss value
   - messaging: (EvaluateRequest <-> EvaluateReply)
 - Server:
   - aggregate local loss values into a global loss metric
   - aggregate all other evaluation metrics and log their values
-  - checkpoint the model and the global loss
+  - optionally checkpoint the model, optimizer, aggregated evaluation
+    metrics and client-wise ones

 ### Overview of the declearn API
......
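
As a rough orientation for the checkpointing behaviour described above, the contents of a checkpoint folder after a few rounds can be sketched as follows (inferred from the diffs and the unit tests at the bottom of this page; the `outputs` folder and the exact file set are illustrative, not guaranteed):

```python
import os

# List an assumed server-side checkpoint folder after a short run.
for name in sorted(os.listdir("outputs")):
    print(name)
# Roughly expected entries (illustrative):
#   metrics.csv / metrics.json                    aggregated metrics, keyed by timestamp
#   metrics_<client>.csv / metrics_<client>.json  client-wise metrics, written by the
#                                                 new `_checkpoint_after_evaluation`
#   model_config.json, model_state_<timestamp>.json
#   optimizer_config.json, optimizer_state_<timestamp>.json
#   model_state_best.json                         final lowest-loss weights
```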
@@ -5,15 +5,13 @@
 import asyncio
 import dataclasses
 import logging
-import os
 from typing import Any, Dict, Optional, Union

 from declearn.communication import NetworkClientConfig, messaging
 from declearn.communication.api import NetworkClient
 from declearn.dataset import Dataset, load_dataset_from_json
 from declearn.main.utils import Checkpointer, TrainingManager
-from declearn.utils import get_logger, json_dump
+from declearn.utils import get_logger

 __all__ = [
@@ -31,7 +29,7 @@ class FederatedClient:
         netwk: Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str],
         train_data: Union[Dataset, str],
         valid_data: Optional[Union[Dataset, str]] = None,
-        folder: Optional[str] = None,
+        checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
         share_metrics: bool = True,
         logger: Union[logging.Logger, str, None] = None,
     ) -> None:
@@ -51,11 +49,11 @@ class FederatedClient:
             Optional Dataset instance wrapping validation data, or
             path to a JSON file from which it can be instantiated.
             If None, run evaluation rounds over `train_data`.
-        folder: str or None, default=None
-            Optional folder where to write out a model dump, round-
-            wise weights checkpoints and local validation losses.
-            If None, only record the loss metric and lowest-loss-
-            yielding weights in memory (under `self.checkpoint`).
+        checkpoint: Checkpointer or dict or str or None, default=None
+            Optional Checkpointer instance or instantiation dict to be
+            used so as to save round-wise model, optimizer and metrics.
+            If a single string is provided, treat it as the checkpoint
+            folder path and use default values for other parameters.
         share_metrics: bool, default=True
             Whether to share evaluation metrics with the server,
             or save them locally and only send the model's loss.
@@ -102,9 +100,10 @@ class FederatedClient:
         if not (valid_data is None or isinstance(valid_data, Dataset)):
             raise TypeError("'valid_data' should be a Dataset or path to one.")
         self.valid_data = valid_data
-        # Record the checkpointing folder and create a Checkpointer slot.
-        self.folder = folder
-        self.checkpointer = None  # type: Optional[Checkpointer]
+        # Assign an optional checkpointer.
+        if checkpoint is not None:
+            checkpoint = Checkpointer.from_specs(checkpoint)
+        self.ckptr = checkpoint
         # Record the metric-sharing boolean switch.
         self.share_metrics = bool(share_metrics)
         # Create a TrainingManager slot, populated at initialization phase.
@@ -249,13 +248,16 @@ class FederatedClient:
             metrics=message.metrics,
             logger=self.logger,
         )
-        # Instantiate a checkpointer and save the initial model.
-        self.checkpointer = Checkpointer(message.model, self.folder)
-        self.checkpointer.save_model()
-        self.checkpointer.checkpoint(float("inf"))  # initial weights
         # If instructed to do so, await a PrivacyRequest to set up DP-SGD.
         if message.dpsgd:
             await self._initialize_dpsgd()
+        # Optionally checkpoint the received model and optimizer.
+        if self.ckptr:
+            self.ckptr.checkpoint(
+                model=self.trainmanager.model,
+                optimizer=self.trainmanager.optim,
+                first_call=True,
+            )

     async def _initialize_dpsgd(
         self,
@@ -309,6 +311,7 @@
         # lazy-import the DPTrainingManager, that involves some optional,
         # heavy-loadtime dependencies; pylint: disable=import-outside-toplevel
         from declearn.main.privacy import DPTrainingManager
+
         # pylint: enable=import-outside-toplevel
         self.trainmanager = DPTrainingManager(
             model=self.trainmanager.model,
@@ -368,9 +371,13 @@
         reply = self.trainmanager.evaluation_round(message)
         # Post-process the results.
         if isinstance(reply, messaging.EvaluationReply):  # not an Error
-            # Checkpoint the model and record the local loss.
-            if self.checkpointer is not None:  # True in `run` context
-                self.checkpointer.checkpoint(reply.loss)
+            # Optionally checkpoint the model, optimizer and local loss.
+            if self.ckptr:
+                self.ckptr.checkpoint(
+                    model=self.trainmanager.model,
+                    optimizer=self.trainmanager.optim,
+                    metrics=self.trainmanager.metrics.get_result(),
+                )
             # Optionally prevent sharing metrics (save for the loss).
             if not self.share_metrics:
                 reply.metrics.clear()
@@ -393,17 +400,12 @@
             message.rounds,
             message.loss,
         )
-        if self.folder is not None:
-            # Save the locally-best-performing model weights.
-            if self.checkpointer is not None:  # True in `run` context
-                path = os.path.join(self.folder, "best_local_weights.json")
-                self.logger.info("Saving best local weights in '%s'.", path)
-                self.checkpointer.reset_best_weights()
-                json_dump(self.checkpointer.model.get_weights(), path)
-            # Save the globally-best-performing model weights.
-            path = os.path.join(self.folder, "final_weights.json")
-            self.logger.info("Saving final weights in '%s'.", path)
-            json_dump(message.weights, path)
+        if self.ckptr:
+            path = f"{self.ckptr.folder}/model_state_best.json"
+            self.logger.info("Checkpointing final weights under %s.", path)
+            assert self.trainmanager is not None  # for mypy
+            self.trainmanager.model.set_weights(message.weights)
+            self.ckptr.save_model(self.trainmanager.model, timestamp="best")

     async def cancel_training(
         self,
......
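
The final routine above leaves a reloadable dump of the "best" weights on the client side. A hypothetical post-run snippet (the `outputs` folder is assumed), based on the `Checkpointer.load_model` behaviour exercised in the new unit tests:

```python
from declearn.main.utils import Checkpointer

# Reload the "best" model dumped by the client's final routine above.
# `model=None` has the checkpointer rebuild the model from model_config.json;
# `timestamp="best"` points it at the model_state_best.json weights file.
ckptr = Checkpointer(folder="outputs")
model = ckptr.load_model(model=None, timestamp="best")
```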
@@ -24,7 +24,7 @@ from declearn.main.utils import (
     aggregate_clients_data_info,
 )

 from declearn.metrics import MetricInputType, MetricSet
-from declearn.model.api import Model
+from declearn.model.api import Model, Vector
 from declearn.utils import deserialize_object, get_logger
@@ -47,7 +47,7 @@ class FederatedServer:
         netwk: Union[NetworkServer, NetworkServerConfig, Dict[str, Any], str],
         optim: Union[FLOptimConfig, str, Dict[str, Any]],
         metrics: Union[MetricSet, List[MetricInputType], None] = None,
-        folder: Optional[str] = None,
+        checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
         logger: Union[logging.Logger, str, None] = None,
     ) -> None:
         """Instantiate the orchestrating server for a federated learning task.
@@ -72,11 +72,11 @@ class FederatedServer:
             to wrap into one, defining evaluation metrics to compute in
             addition to the model's loss.
             If None, only compute and report the model's loss.
-        folder: str or None, default=None
-            Optional folder where to write out a model dump, round-
-            wise weights checkpoints and global validation losses.
-            If None, only record the loss metric and lowest-loss-
-            yielding weights in memory (under `self.checkpoint`).
+        checkpoint: Checkpointer or dict or str or None, default=None
+            Optional Checkpointer instance or instantiation dict to be
+            used so as to save round-wise model, optimizer and metrics.
+            If a single string is provided, treat it as the checkpoint
+            folder path and use default values for other parameters.
         logger: logging.Logger or str or None, default=None,
             Logger to use, or name of a logger to set up with
             `declearn.utils.get_logger`. If None, use `type(self)`.
@@ -125,8 +125,13 @@
         self.c_opt = optim.client_opt
         # Assign the wrapped MetricSet.
         self.metrics = MetricSet.from_specs(metrics)
-        # Assign a model checkpointer.
-        self.checkpointer = Checkpointer(self.model, folder)
+        # Assign an optional checkpointer.
+        if checkpoint is not None:
+            checkpoint = Checkpointer.from_specs(checkpoint)
+        self.ckptr = checkpoint
+        # Set up private attributes to record the loss values and best weights.
+        self._loss = {}  # type: Dict[int, float]
+        self._best = None  # type: Optional[Vector]

     def run(
         self,
@@ -177,8 +182,8 @@
         async with self.netwk:
             # Conduct the initialization phase.
             await self.initialization(config)
-            self.checkpointer.save_model()
-            self.checkpointer.checkpoint(float("inf"))  # save initial weights
+            if self.ckptr:
+                self.ckptr.checkpoint(self.model, self.optim, first_call=True)
             # Iteratively run training and evaluation rounds.
             round_i = 0
             while True:
@@ -478,6 +483,7 @@
             EvaluateConfig dataclass instance wrapping data-batching
             and computational effort constraints hyper-parameters.
         """
+        # Send evaluation requests and collect clients' replies.
         self.logger.info("Initiating evaluation round %s", round_i)
         clients = self._select_evaluation_round_participants()
         await self._send_evaluation_instructions(clients, round_i, valid_cfg)
@@ -485,12 +491,22 @@
         results = await self._collect_results(
             clients, messaging.EvaluationReply, "evaluation"
         )
+        # Compute and report aggregated evaluation metrics.
         self.logger.info("Aggregating evaluation results.")
         loss, metrics = self._aggregate_evaluation_results(results)
-        self.logger.info("Global loss is: %s", loss)
+        self.logger.info("Averaged loss is: %s", loss)
         if metrics:
-            self.logger.info("Other global metrics are: %s", metrics)
-        self.checkpointer.checkpoint(loss)
+            self.logger.info(
+                "Other averaged scalar metrics are: %s",
+                {k: v for k, v in metrics.items() if isinstance(v, float)},
+            )
+        # Optionally checkpoint the model, optimizer and metrics.
+        if self.ckptr:
+            self._checkpoint_after_evaluation(metrics, results)
+        # Record the global loss, and update the kept "best" weights.
+        self._loss[round_i] = loss
+        if loss == min(self._loss.values()):
+            self._best = self.model.get_weights()

     def _select_evaluation_round_participants(
         self,
@@ -551,6 +567,7 @@
             # Case when the client reported some metrics.
             if reply.metrics:
                 states = reply.metrics.copy()
+                # Update the global metrics based on the local ones.
                 s_loss = states.pop("loss")
                 loss += s_loss["current"]  # type: ignore
                 dvsr += s_loss["divisor"]  # type: ignore
@@ -567,6 +584,50 @@
         loss = loss / dvsr
         return loss, metrics

+    def _checkpoint_after_evaluation(
+        self,
+        metrics: Dict[str, Union[float, np.ndarray]],
+        results: Dict[str, messaging.EvaluationReply],
+    ) -> None:
+        """Checkpoint the current model, optimizer and evaluation metrics.
+
+        This method is meant to be called at the end of an evaluation round.
+
+        Parameters
+        ----------
+        metrics: dict[str, (float|np.ndarray)]
+            Aggregated evaluation metrics to checkpoint.
+        results: dict[str, EvaluationReply]
+            Client-wise EvaluationReply messages, based on which
+            `metrics` were already computed.
+        """
+        # This method only works when a checkpointer is used.
+        if self.ckptr is None:
+            raise RuntimeError(
+                "`_checkpoint_after_evaluation` was called without "
+                "the FederatedServer having a Checkpointer."
+            )
+        # Checkpoint the model, optimizer and global evaluation metrics.
+        timestamp = self.ckptr.checkpoint(
+            model=self.model, optimizer=self.optim, metrics=metrics
+        )
+        # Checkpoint the client-wise metrics (or at least their loss).
+        # Use the same timestamp label as for global metrics and states.
+        local = MetricSet.from_config(self.metrics.get_config())
+        for client, reply in results.items():
+            if reply.metrics:
+                local.reset()
+                local.agg_states(reply.metrics)
+                metrics = local.get_result()
+            else:
+                metrics = {"loss": reply.loss}
+            self.ckptr.save_metrics(
+                metrics=metrics,
+                prefix=f"metrics_{client}",
+                append=bool(self._loss),
+                timestamp=timestamp,
+            )
+
     def _keep_training(
         self,
         round_i: int,
@@ -589,7 +650,7 @@
             self.logger.info("Maximum number of training rounds reached.")
             return False
         if early_stop is not None:
-            early_stop.update(self.checkpointer.get_loss(round_i))
+            early_stop.update(self._loss[round_i])
             if not early_stop.keep_training:
                 self.logger.info("Early stopping criterion reached.")
                 return False
@@ -606,11 +667,16 @@
         rounds: int
             Number of training rounds taken until now.
         """
-        self.checkpointer.reset_best_weights()
+        self.logger.info("Recovering weights that yielded the lowest loss.")
         message = messaging.StopTraining(
-            weights=self.model.get_weights(),
-            loss=min(self.checkpointer.get_loss(i) for i in range(rounds)),
+            weights=self._best or self.model.get_weights(),
+            loss=min(self._loss.values()) if self._loss else float("nan"),
             rounds=rounds,
         )
         self.logger.info("Notifying clients that training is over.")
         await self.netwk.broadcast_message(message)
+        if self.ckptr:
+            path = f"{self.ckptr.folder}/model_state_best.json"
+            self.logger.info("Checkpointing final weights under %s.", path)
+            self.model.set_weights(message.weights)
+            self.ckptr.save_model(self.model, timestamp="best")
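
Since scalar metrics are now written to per-checkpointer CSV/JSON files rather than only kept in memory, they can be reloaded for analysis once a run has completed. A hypothetical snippet (folder name assumed), mirroring the `load_scalar_metrics` test further down:

```python
from declearn.main.utils import Checkpointer

# Reload the scalar metrics checkpointed by the server during the run.
ckptr = Checkpointer(folder="outputs")
scores = ckptr.load_scalar_metrics()  # pandas DataFrame indexed by timestamp
print(scores)
```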
This diff is collapsed.
@@ -327,7 +327,10 @@ class TrainingManager:
         effort = constraints.get_values()
         result = self.metrics.get_result()
         states = self.metrics.get_states()
-        self.logger.info("Local evaluation metrics: %s", result)
+        self.logger.info(
+            "Local scalar evaluation metrics: %s",
+            {k: v for k, v in result.items() if isinstance(v, float)},
+        )
         # Pack the result and computational effort information into a message.
         self.logger.info("Packing local results to be sent to the server.")
         return messaging.EvaluationReply(
......
@@ -68,7 +68,8 @@ def run_client(
     # (5) Instantiate a FederatedClient and run it.
     client = FederatedClient(
-        network, train, valid, folder=f"{FILEDIR}/results/{name}"
+        # fmt: off
+        network, train, valid, checkpoint=f"{FILEDIR}/results/{name}"
         # Note: you may add `share_metrics=False` to prevent sending
         # evaluation metrics to the server, out of privacy concerns
     )
......
@@ -2,7 +2,7 @@

 import argparse
 import os
-from typing import List
+from typing import Collection

 import pandas as pd
@@ -29,14 +29,14 @@ DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
 def get_data(
-    dir: str = DATADIR,
-    names: List[str] = NAMES,
+    folder: str = DATADIR,
+    names: Collection[str] = NAMES,
 ) -> None:
     """Download and process the UCI heart disease dataset.

     Arguments
     ---------
-    dir: str
+    folder: str
         Path to the folder where to write output csv files.
     names: list[str]
         Names of centers, the dataset from which to download,
@@ -61,8 +61,8 @@ def get_data(
         # Binarize the target variable.
         df["num"] = (df["num"] > 0).astype(int)
         # Export the resulting dataset to a csv file.
-        os.makedirs(dir, exist_ok=True)
-        df.to_csv(f"{dir}/{name}.csv", index=False)
+        os.makedirs(folder, exist_ok=True)
+        df.to_csv(f"{folder}/{name}.csv", index=False)

 # Code executed when the script is called directly.
@@ -70,7 +70,7 @@ if __name__ == "__main__":
     # Parse commandline parameters.
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--dir",
+        "--folder",
         type=str,
         default=DATADIR,
         help="folder where to write output csv files",
@@ -84,4 +84,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     # Download and pre-process the selected dataset(s).
-    get_data(dir=args.dir, names=args.names)
+    get_data(folder=args.folder, names=args.names)
@@ -5,10 +5,9 @@ import os
 from declearn.communication import NetworkServerConfig
 from declearn.main import FederatedServer
-from declearn.main.config import FLRunConfig, FLOptimConfig
+from declearn.main.config import FLOptimConfig, FLRunConfig
 from declearn.model.sklearn import SklearnSGDModel

 FILEDIR = os.path.dirname(os.path.abspath(__file__))
@@ -85,7 +84,10 @@ def run_server(
     # f1-score and roc auc (with plot-enabling fpr/tpr curves) during
     # evaluation rounds.
     server = FederatedServer(
-        model, network, optim, metrics=["binary-classif", "binary-roc"]
+        # fmt: off
+        model, network, optim,
+        metrics=["binary-classif", "binary-roc"],
+        checkpoint=f"{FILEDIR}/results/server"
     )
     # Here, we set up 20 rounds of training, with 30 samples per batch
......
# coding: utf-8

"""Unit tests for Checkpointer class."""

import json
import os
from pathlib import Path
from typing import Dict, Iterator, List, Union
from unittest import mock

import numpy as np
import pandas as pd
import pytest
from sklearn.linear_model import SGDClassifier

from declearn.main.utils import Checkpointer
from declearn.model.api import Model
from declearn.model.sklearn import SklearnSGDModel
from declearn.optimizer import Optimizer
from declearn.utils import json_load


# Fixtures and utils


@pytest.fixture(name="checkpointer")
def fixture_checkpointer(tmp_path) -> Iterator[Checkpointer]:
    """Create a checkpointer within a temp dir."""
    yield Checkpointer(tmp_path, 2)


@pytest.fixture(name="model")
def fixture_model() -> SklearnSGDModel:
    """Create a toy binary-classification model."""
    model = SklearnSGDModel(SGDClassifier())
    model.initialize({"n_features": 8, "classes": np.arange(2)})
    return model


@pytest.fixture(name="optimizer")
def fixture_optimizer() -> Optimizer:
    """Create a toy optimizer."""
    testopt = Optimizer(lrate=1.0, modules=[("momentum", {"beta": 0.95})])
    return testopt


@pytest.fixture(name="metrics")
def fixture_metrics() -> Dict[str, float]:
    """Create a metrics fixture."""
    return {"loss": 0.5}


def create_state_files(folder: str, type_obj: str, n_files: int) -> List[str]:
    """Create test state files in the checkpointer's folder."""
    files = [
        f"{type_obj}_state_23-01-{21 + idx}_15-45-35.json"
        for idx in range(n_files)
    ]
    for name in files:
        with open(os.path.join(folder, name), "w", encoding="utf-8") as file:
            json.dump({"test": "state"}, file)
    return files


def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str:
    """Create test config files in the checkpointer's folder."""
    path = os.path.join(checkpointer.folder, f"{type_obj}_config.json")
    with open(path, "w", encoding="utf-8") as file:
        json.dump({"test": "config"}, file)
    return f"{type_obj}_config.json"


# Actual tests


class TestCheckpointer:
    """Unit tests for Checkpointer class."""

    def test_init_default(self, tmp_path: str) -> None:
        """Test `Checkpointer.__init__` with `max_history=None`."""
        checkpointer = Checkpointer(folder=tmp_path, max_history=None)
        assert checkpointer.folder == tmp_path
        assert Path(checkpointer.folder).is_dir()
        assert checkpointer.max_history is None

    def test_init_max_history(self, tmp_path: str) -> None:
        """Test `Checkpointer.__init__` with `max_history=2`."""
        checkpointer = Checkpointer(folder=tmp_path, max_history=2)
        assert checkpointer.folder == tmp_path
        assert Path(checkpointer.folder).is_dir()
        assert checkpointer.max_history == 2

    def test_init_fails(self, tmp_path: str) -> None:
        """Test `Checkpointer.__init__` raises on negative `max_history`."""
        with pytest.raises(TypeError):
            Checkpointer(folder=tmp_path, max_history=-1)

    def test_from_specs(self, tmp_path: str) -> None:
        """Test that `Checkpointer.from_specs` works properly.

        This test is multi-part rather than unitary as the method
        is merely boilerplate code refactored into a classmethod.
        """
        tmp_path = str(tmp_path)  # note: PosixPath
        specs_list = [
            tmp_path,
            {"folder": tmp_path, "max_history": None},
            Checkpointer(tmp_path),
        ]
        # Iteratively test the various types of acceptable specs.
        for specs in specs_list:
            ckpt = Checkpointer.from_specs(specs)  # type: ignore
            assert isinstance(ckpt, Checkpointer)
            assert ckpt.folder == tmp_path
            assert ckpt.max_history is None
        # Also test that the documented TypeError is raised.
        with pytest.raises(TypeError):
            Checkpointer.from_specs(0)  # type: ignore

    def test_garbage_collect(self, tmp_path: str) -> None:
        """Test `Checkpointer.garbage_collect` when collection is needed."""
        # Set up a checkpointer with max_history=2 and 3 state files.
        checkpointer = Checkpointer(folder=tmp_path, max_history=2)
        names = sorted(create_state_files(tmp_path, "model", n_files=3))
        checkpointer.garbage_collect("model_state")
        # Verify that the "oldest" file was removed.
        files = sorted(os.listdir(checkpointer.folder))
        assert len(files) == checkpointer.max_history
        assert files == names[1:]  # i.e. [-max_history:]

    def test_garbage_collect_no_collection(self, tmp_path: str) -> None:
        """Test `Checkpointer.garbage_collect` when collection is unneeded."""
        # Set up a checkpointer with max_history=3 and 2 state files.
        checkpointer = Checkpointer(folder=tmp_path, max_history=3)
        names = sorted(create_state_files(tmp_path, "model", n_files=2))
        checkpointer.garbage_collect("model_state")
        # Verify that no files were removed.
        files = sorted(os.listdir(checkpointer.folder))
        assert files == names

    def test_garbage_collect_infinite_history(self, tmp_path: str) -> None:
        """Test `Checkpointer.garbage_collect` when `max_history=None`."""
        # Set up a checkpointer with max_history=None and 3 state files.
        checkpointer = Checkpointer(folder=tmp_path, max_history=None)
        names = sorted(create_state_files(tmp_path, "model", n_files=3))
        checkpointer.garbage_collect("model_state")
        # Verify that no files were removed.
        files = sorted(os.listdir(checkpointer.folder))
        assert files == names

    def test_sort_matching_files(self, tmp_path: str) -> None:
        """Test `Checkpointer.sort_matching_files`."""
        checkpointer = Checkpointer(folder=tmp_path)
        names = sorted(create_state_files(tmp_path, "model", n_files=3))
        create_state_files(tmp_path, "optimizer", n_files=2)
        files = checkpointer.sort_matching_files("model_state")
        assert names == files

    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
    @pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
    def test_save_model(
        self, tmp_path: str, model: Model, config: bool, state: bool
    ) -> None:
        """Test `Checkpointer.save_model` with provided parameters."""
        checkpointer = Checkpointer(folder=tmp_path)
        timestamp = checkpointer.save_model(model, config, state)
        # Verify config save file's existence.
        cfg_path = os.path.join(checkpointer.folder, "model_config.json")
        if config:
            assert Path(cfg_path).is_file()
        else:
            assert not Path(cfg_path).is_file()
        # Verify weights save file's existence.
        if state:  # test state file save
            assert isinstance(timestamp, str)
            state_path = os.path.join(
                checkpointer.folder, f"model_state_{timestamp}.json"
            )
            assert Path(state_path).is_file()
        else:
            assert timestamp is None
            assert not checkpointer.sort_matching_files("model_state")

    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
    @pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
    def test_save_optimizer(
        self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
    ) -> None:
        """Test `Checkpointer.save_optimizer` with provided parameters."""
        checkpointer = Checkpointer(folder=tmp_path)
        timestamp = checkpointer.save_optimizer(optimizer, config, state)
        # Verify config save file's existence.
        cfg_path = os.path.join(checkpointer.folder, "optimizer_config.json")
        if config:
            assert Path(cfg_path).is_file()
        else:
            assert not Path(cfg_path).is_file()
        # Verify state save file's existence.
        if state:
            assert isinstance(timestamp, str)
            state_path = os.path.join(
                checkpointer.folder, f"optimizer_state_{timestamp}.json"
            )
            assert Path(state_path).is_file()
        else:
            assert timestamp is None
            assert not checkpointer.sort_matching_files("optimizer_state")

    def test_save_metrics(self, tmp_path: str) -> None:
        """Test that `Checkpointer.save_metrics` works as expected.

        This is a multi-part test rather than unit one, to verify
        that the `append` parameter and its backend work properly.
        """
        # Setup for this multi-part test.
        metrics = {
            "foo": 42.0,
            "bar": np.array([0, 1]),
        }  # type: Dict[str, Union[float, np.ndarray]]
        checkpointer = Checkpointer(tmp_path)
        csv_path = os.path.join(tmp_path, "metrics.csv")
        json_path = os.path.join(tmp_path, "metrics.json")
        # Case 'append=True' but the files do not exist.
        checkpointer.save_metrics(metrics, append=True, timestamp="0")
        assert os.path.isfile(csv_path)
        assert os.path.isfile(json_path)
        scalars = pd.DataFrame({"timestamp": [0], "foo": [42.0]})
        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
        m_json = {"foo": 42.0, "bar": [0, 1]}
        assert json_load(json_path) == {"0": m_json}
        # Case 'append=False', overwriting existing files.
        checkpointer.save_metrics(metrics, append=False, timestamp="0")
        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
        assert json_load(json_path) == {"0": m_json}
        # Case 'append=True', appending to existing files.
        checkpointer.save_metrics(metrics, append=True, timestamp="1")
        scalars = pd.DataFrame({"timestamp": [0, 1], "foo": [42.0, 42.0]})
        m_json = {"0": m_json, "1": m_json}
        assert (pd.read_csv(csv_path) == scalars).all(axis=None)
        assert json_load(json_path) == m_json

    @pytest.mark.parametrize("first", [True, False], ids=["first", "notfirst"])
    def test_checkpoint(
        self, tmp_path: str, model: Model, optimizer: Optimizer, first: bool
    ) -> None:
        """Test that `Checkpointer.checkpoint` works as expected."""
        # Set up a checkpointer and call its checkpoint method.
        checkpointer = Checkpointer(tmp_path)
        metrics = {"foo": 42.0, "bar": np.array([0, 1])}
        if first:  # create some files that should be removed on `first_call`
            create_config_file(checkpointer, "model")
        timestamp = checkpointer.checkpoint(
            model=model,
            optimizer=optimizer,
            metrics=metrics,  # type: ignore
            first_call=first,
        )
        assert isinstance(timestamp, str)
        # Verify whether config and metric files exist, as expected.
        m_cfg = os.path.join(tmp_path, "model_config.json")
        o_cfg = os.path.join(tmp_path, "optimizer_config.json")
        if first:
            assert os.path.isfile(m_cfg)
            assert os.path.isfile(o_cfg)
        else:
            assert not os.path.isfile(m_cfg)
            assert not os.path.isfile(o_cfg)
        # Verify that state and metric files exist as expected.
        path = os.path.join(tmp_path, f"model_state_{timestamp}.json")
        assert os.path.isfile(path)
        path = os.path.join(tmp_path, f"optimizer_state_{timestamp}.json")
        assert os.path.isfile(path)
        assert os.path.isfile(os.path.join(tmp_path, "metrics.csv"))
        assert os.path.isfile(os.path.join(tmp_path, "metrics.json"))

    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
    @pytest.mark.parametrize("config", [True, False], ids=["config", "model"])
    def test_load_model(
        self, tmp_path: str, model: Model, config: bool, state: bool
    ) -> None:
        """Test `Checkpointer.load_model` with provided parameters."""
        checkpointer = Checkpointer(tmp_path)
        # Save the model (config + weights), then reload based on parameters.
        timestamp = checkpointer.save_model(model, config=True, state=True)
        with mock.patch.object(type(model), "set_weights") as p_set_weights:
            loaded_model = checkpointer.load_model(
                model=(None if config else model),
                timestamp=(timestamp if config else None),  # arbitrary swap
                load_state=state,
            )
        # Verify that the loaded model is either the input one or similar.
        if config:
            assert isinstance(loaded_model, type(model))
            assert loaded_model is not model
            assert loaded_model.get_config() == model.get_config()
        else:
            assert loaded_model is model
        # Verify that `set_weights` was called, with proper values.
        if state:
            p_set_weights.assert_called_once()
            if config:
                assert loaded_model.get_weights() == model.get_weights()
        else:
            p_set_weights.assert_not_called()

    def test_load_model_fails(self, tmp_path: str, model: Model) -> None:
        """Test that `Checkpointer.load_model` raises expected errors."""
        checkpointer = Checkpointer(tmp_path)
        # Case when the weights file is missing.
        checkpointer.save_model(model, config=False, state=False)
        with pytest.raises(FileNotFoundError):
            checkpointer.load_model(model=model, load_state=True)
        # Case when the config file is missing.
        checkpointer.save_model(model, config=False, state=True)
        with pytest.raises(FileNotFoundError):
            checkpointer.load_model(model=None)
        # Case when a wrong model input is provided.
        with pytest.raises(TypeError):
            checkpointer.load_model(model="wrong-type")  # type: ignore

    @pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
    @pytest.mark.parametrize("config", [True, False], ids=["config", "optim"])
    def test_load_optimizer(
        self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
    ) -> None:
        """Test `Checkpointer.load_optimizer` with provided parameters."""
        checkpointer = Checkpointer(tmp_path)
        # Save the optimizer (config + state), then reload based on parameters.
        stamp = checkpointer.save_optimizer(optimizer, config=True, state=True)
        with mock.patch.object(Optimizer, "set_state") as p_set_state:
            loaded_optim = checkpointer.load_optimizer(
                optimizer=(None if config else optimizer),
                timestamp=(stamp if config else None),  # arbitrary swap
                load_state=state,
            )
        # Verify that the loaded optimizer is either the input one or similar.
        if config:
            assert isinstance(loaded_optim, Optimizer)
            assert loaded_optim is not optimizer
            assert loaded_optim.get_config() == optimizer.get_config()
        else:
            assert loaded_optim is optimizer
        # Verify that `set_state` was called, with proper values.
        if state:
            p_set_state.assert_called_once()
            if config:
                assert loaded_optim.get_state() == optimizer.get_state()
        else:
            p_set_state.assert_not_called()

    def test_load_optimizer_fails(
        self, tmp_path: str, optimizer: Optimizer
    ) -> None:
        """Test that `Checkpointer.load_optimizer` raises expected errors."""
        checkpointer = Checkpointer(tmp_path)
        # Case when the state file is missing.
        checkpointer.save_optimizer(optimizer, config=False, state=False)
        with pytest.raises(FileNotFoundError):
            checkpointer.load_optimizer(optimizer=optimizer, load_state=True)
        # Case when the config file is missing.
        checkpointer.save_optimizer(optimizer, config=False, state=True)
        with pytest.raises(FileNotFoundError):
            checkpointer.load_optimizer(optimizer=None)
        # Case when a wrong optimizer input is provided.
        with pytest.raises(TypeError):
            checkpointer.load_optimizer(optimizer="wrong-type")  # type: ignore

    def test_load_metrics(self, tmp_path: str) -> None:
        """Test that `Checkpointer.load_metrics` works properly."""
        # Setup things by saving a couple of sets of metrics.
        metrics = {
            "foo": 42.0,
            "bar": np.array([0, 1]),
        }  # type: Dict[str, Union[float, np.ndarray]]
        checkpointer = Checkpointer(tmp_path)
        time_0 = checkpointer.save_metrics(metrics, append=False)
        time_1 = checkpointer.save_metrics(metrics, append=True)
        # Test reloading all checkpointed metrics.
        reloaded = checkpointer.load_metrics(timestamp=None)
        assert isinstance(reloaded, dict)
        assert reloaded.keys() == {time_0, time_1}
        for scores in reloaded.values():
            assert isinstance(scores, dict)
            assert scores.keys() == metrics.keys()
            assert scores["foo"] == metrics["foo"]
            assert (scores["bar"] == metrics["bar"]).all()  # type: ignore
        # Test reloading only metrics from one timestamp.
        reloaded = checkpointer.load_metrics(timestamp=time_0)
        assert isinstance(reloaded, dict)
        assert reloaded.keys() == {time_0}

    def test_load_scalar_metrics(self, tmp_path: str) -> None:
        """Test that `Checkpointer.load_scalar_metrics` works properly."""
        # Setup things by saving a couple of sets of metrics.
        metrics = {
            "foo": 42.0,
            "bar": np.array([0, 1]),
        }  # type: Dict[str, Union[float, np.ndarray]]
        checkpointer = Checkpointer(tmp_path)
        time_0 = checkpointer.save_metrics(metrics, append=False)
        time_1 = checkpointer.save_metrics(metrics, append=True)
        expect = pd.DataFrame(
            {"foo": [42.0, 42.0], "timestamp": [time_0, time_1]}
        ).set_index("timestamp")
        # Test reloading scalar metrics.
        scores = checkpointer.load_scalar_metrics()
        assert isinstance(scores, pd.DataFrame)
        assert scores.index.names == expect.index.names
        assert scores.columns == expect.columns
        assert scores.shape == expect.shape
        assert (scores == expect).all(axis=None)
...@@ -193,7 +193,7 @@ class DeclearnTestCase: ...@@ -193,7 +193,7 @@ class DeclearnTestCase:
netwk = self.build_netwk_server() netwk = self.build_netwk_server()
optim = self.build_optim_config() optim = self.build_optim_config()
with tempfile.TemporaryDirectory() as folder: with tempfile.TemporaryDirectory() as folder:
server = FederatedServer(model, netwk, optim, folder=folder) server = FederatedServer(model, netwk, optim, checkpoint=folder)
config = { config = {
"rounds": self.rounds, "rounds": self.rounds,
"register": {"max_clients": self.nb_clients, "timeout": 20}, "register": {"max_clients": self.nb_clients, "timeout": 20},
......