Commit ddc04028 authored by ANDREY Paul

Merge branch 'revise-checkpointer' into 'main'

Revise `Checkpointer`.

See merge request !21
parents 5f23f935 4cf8b549
@@ -182,7 +182,9 @@ optim = declearn.main.FLOptimConfig.from_params(
aggregator="averaging",
client_opt=0.001,
)
server = declearn.main.FederatedServer(model, netwk, optim, folder="outputs")
server = declearn.main.FederatedServer(
model, netwk, optim, checkpoint="outputs"
)
config = declearn.main.config.FLRunConfig.from_params(
rounds=10,
register={"min_clients": 1, "max_clients": 3, "timeout": 180},
@@ -206,7 +208,7 @@ train = declearn.dataset.InMemoryDataset(
expose_classes=True # enable sharing of unique target values
)
valid = declearn.dataset.InMemoryDataset("path/to/valid.csv", target="label")
client = declearn.main.FederatedClient(netwk, train, valid, folder="outputs")
client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs")
client.run()
```
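Beyond a plain folder path, the new `checkpoint` argument should also accept a pre-built `Checkpointer` instance or an instantiation dict, resolved through `Checkpointer.from_specs` (as exercised by the unit tests added below, where `max_history` bounds the number of retained state files). A minimal sketch, with illustrative values:

```python
from declearn.main.utils import Checkpointer

# Keep only the two most recent model/optimizer state files.
client = declearn.main.FederatedClient(
    netwk, train, valid,
    checkpoint={"folder": "outputs", "max_history": 2},
)
# Equivalently, pass a pre-built Checkpointer instance.
client = declearn.main.FederatedClient(
    netwk, train, valid, checkpoint=Checkpointer("outputs", max_history=2)
)
```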
@@ -249,9 +251,10 @@ exposed here.
- decide whether to continue, based on the number of
rounds taken or on the evolution of the global loss
- Finally:
- restore the model weights that yielded the lowest global loss
- notify clients that training is over, so they can disconnect
and run their final routine (e.g. model saving)
- optionally save the model (through a checkpointer)
and run their final routine (e.g. save the "best" model)
- optionally checkpoint the "best" model
- close the network server and end the process
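As a hedged illustration, the gist of this final phase on the server side, condensed from the `FederatedServer.stop_training` changes shown further down in this diff (`self._best` and `self._loss` are the new attributes recording the lowest-loss weights and round-wise loss values):

```python
# Gist only, not verbatim library code.
message = messaging.StopTraining(
    weights=self._best or self.model.get_weights(),
    loss=min(self._loss.values()) if self._loss else float("nan"),
    rounds=rounds,
)
await self.netwk.broadcast_message(message)
if self.ckptr:
    # Optionally save the "best" model under a fixed label, yielding
    # a "<checkpoint-folder>/model_state_best.json" file.
    self.model.set_weights(message.weights)
    self.ckptr.save_model(self.model, timestamp="best")
```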
#### Detail of the process phases
@@ -319,14 +322,15 @@ exposed here.
- update model weights
- perform evaluation steps based on effort constraints
- step: update evaluation metrics, including the model's loss, over a batch
- checkpoint the model, then send results to the server
- optionally prevent sharing detailed metrics with the server; always
include the scalar validation loss value
- optionally checkpoint the model, local optimizer and evaluation metrics
- send results to the server: optionally prevent sharing detailed metrics;
always include the scalar validation loss value
- messaging: (EvaluateRequest <-> EvaluateReply)
- Server:
- aggregate local loss values into a global loss metric
- aggregate all other evaluation metrics and log their values
- checkpoint the model and the global loss
- optionally checkpoint the model, optimizer, aggregated evaluation
metrics and client-wise ones
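A hedged sketch of what one such server-side checkpoint writes to disk, based on the `Checkpointer` calls and file-naming conventions exercised by the unit tests added in this merge request; the `model`, `optim` and `metrics` names below are assumed to hold the server's model, optimizer and aggregated evaluation metrics, and the client name in the prefix is illustrative:

```python
from declearn.main.utils import Checkpointer

ckptr = Checkpointer(folder="outputs/server", max_history=None)
timestamp = ckptr.checkpoint(model=model, optimizer=optim, metrics=metrics)
# Expected files under "outputs/server":
#   model_config.json / optimizer_config.json   (only written when `first_call=True`)
#   model_state_<timestamp>.json                 (current model weights)
#   optimizer_state_<timestamp>.json             (current optimizer state)
#   metrics.csv / metrics.json                   (scalar / full metrics, appended round-wise)
# Client-wise metrics are further saved using a per-client file prefix, e.g.:
ckptr.save_metrics(metrics=metrics, prefix="metrics_client-A", append=True, timestamp=timestamp)
```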
### Overview of the declearn API
@@ -5,15 +5,13 @@
import asyncio
import dataclasses
import logging
import os
from typing import Any, Dict, Optional, Union
from declearn.communication import NetworkClientConfig, messaging
from declearn.communication.api import NetworkClient
from declearn.dataset import Dataset, load_dataset_from_json
from declearn.main.utils import Checkpointer, TrainingManager
from declearn.utils import get_logger, json_dump
from declearn.utils import get_logger
__all__ = [
@@ -31,7 +29,7 @@ class FederatedClient:
netwk: Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str],
train_data: Union[Dataset, str],
valid_data: Optional[Union[Dataset, str]] = None,
folder: Optional[str] = None,
checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
share_metrics: bool = True,
logger: Union[logging.Logger, str, None] = None,
) -> None:
@@ -51,11 +49,11 @@
Optional Dataset instance wrapping validation data, or
path to a JSON file from which it can be instantiated.
If None, run evaluation rounds over `train_data`.
folder: str or None, default=None
Optional folder where to write out a model dump, round-
wise weights checkpoints and local validation losses.
If None, only record the loss metric and lowest-loss-
yielding weights in memory (under `self.checkpoint`).
checkpoint: Checkpointer or dict or str or None, default=None
Optional Checkpointer instance or instantiation dict to be
used so as to save round-wise model, optimizer and metrics.
If a single string is provided, treat it as the checkpoint
folder path and use default values for other parameters.
share_metrics: bool, default=True
Whether to share evaluation metrics with the server,
or save them locally and only send the model's loss.
@@ -102,9 +100,10 @@
if not (valid_data is None or isinstance(valid_data, Dataset)):
raise TypeError("'valid_data' should be a Dataset or path to one.")
self.valid_data = valid_data
# Record the checkpointing folder and create a Checkpointer slot.
self.folder = folder
self.checkpointer = None # type: Optional[Checkpointer]
# Assign an optional checkpointer.
if checkpoint is not None:
checkpoint = Checkpointer.from_specs(checkpoint)
self.ckptr = checkpoint
# Record the metric-sharing boolean switch.
self.share_metrics = bool(share_metrics)
# Create a TrainingManager slot, populated at initialization phase.
@@ -249,13 +248,16 @@ class FederatedClient:
metrics=message.metrics,
logger=self.logger,
)
# Instantiate a checkpointer and save the initial model.
self.checkpointer = Checkpointer(message.model, self.folder)
self.checkpointer.save_model()
self.checkpointer.checkpoint(float("inf")) # initial weights
# If instructed to do so, await a PrivacyRequest to set up DP-SGD.
if message.dpsgd:
await self._initialize_dpsgd()
# Optionally checkpoint the received model and optimizer.
if self.ckptr:
self.ckptr.checkpoint(
model=self.trainmanager.model,
optimizer=self.trainmanager.optim,
first_call=True,
)
async def _initialize_dpsgd(
self,
@@ -309,6 +311,7 @@
# lazy-import the DPTrainingManager, that involves some optional,
# heavy-loadtime dependencies; pylint: disable=import-outside-toplevel
from declearn.main.privacy import DPTrainingManager
# pylint: enable=import-outside-toplevel
self.trainmanager = DPTrainingManager(
model=self.trainmanager.model,
@@ -368,9 +371,13 @@
reply = self.trainmanager.evaluation_round(message)
# Post-process the results.
if isinstance(reply, messaging.EvaluationReply): # not an Error
# Checkpoint the model and record the local loss.
if self.checkpointer is not None: # True in `run` context
self.checkpointer.checkpoint(reply.loss)
# Optionally checkpoint the model, optimizer and local metrics.
if self.ckptr:
self.ckptr.checkpoint(
model=self.trainmanager.model,
optimizer=self.trainmanager.optim,
metrics=self.trainmanager.metrics.get_result(),
)
# Optionally prevent sharing metrics (save for the loss).
if not self.share_metrics:
reply.metrics.clear()
@@ -393,17 +400,12 @@
message.rounds,
message.loss,
)
if self.folder is not None:
# Save the locally-best-performing model weights.
if self.checkpointer is not None: # True in `run` context
path = os.path.join(self.folder, "best_local_weights.json")
self.logger.info("Saving best local weights in '%s'.", path)
self.checkpointer.reset_best_weights()
json_dump(self.checkpointer.model.get_weights(), path)
# Save the globally-best-performing model weights.
path = os.path.join(self.folder, "final_weights.json")
self.logger.info("Saving final weights in '%s'.", path)
json_dump(message.weights, path)
if self.ckptr:
path = f"{self.ckptr.folder}/model_state_best.json"
self.logger.info("Checkpointing final weights under %s.", path)
assert self.trainmanager is not None # for mypy
self.trainmanager.model.set_weights(message.weights)
self.ckptr.save_model(self.trainmanager.model, timestamp="best")
async def cancel_training(
self,
@@ -24,7 +24,7 @@ from declearn.main.utils import (
aggregate_clients_data_info,
)
from declearn.metrics import MetricInputType, MetricSet
from declearn.model.api import Model
from declearn.model.api import Model, Vector
from declearn.utils import deserialize_object, get_logger
@@ -47,7 +47,7 @@ class FederatedServer:
netwk: Union[NetworkServer, NetworkServerConfig, Dict[str, Any], str],
optim: Union[FLOptimConfig, str, Dict[str, Any]],
metrics: Union[MetricSet, List[MetricInputType], None] = None,
folder: Optional[str] = None,
checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
logger: Union[logging.Logger, str, None] = None,
) -> None:
"""Instantiate the orchestrating server for a federated learning task.
@@ -72,11 +72,11 @@
to wrap into one, defining evaluation metrics to compute in
addition to the model's loss.
If None, only compute and report the model's loss.
folder: str or None, default=None
Optional folder where to write out a model dump, round-
wise weights checkpoints and global validation losses.
If None, only record the loss metric and lowest-loss-
yielding weights in memory (under `self.checkpoint`).
checkpoint: Checkpointer or dict or str or None, default=None
Optional Checkpointer instance or instantiation dict to be
used so as to save round-wise model, optimizer and metrics.
If a single string is provided, treat it as the checkpoint
folder path and use default values for other parameters.
logger: logging.Logger or str or None, default=None,
Logger to use, or name of a logger to set up with
`declearn.utils.get_logger`. If None, use `type(self)`.
@@ -125,8 +125,13 @@
self.c_opt = optim.client_opt
# Assign the wrapped MetricSet.
self.metrics = MetricSet.from_specs(metrics)
# Assign a model checkpointer.
self.checkpointer = Checkpointer(self.model, folder)
# Assign an optional checkpointer.
if checkpoint is not None:
checkpoint = Checkpointer.from_specs(checkpoint)
self.ckptr = checkpoint
# Set up private attributes to record the loss values and best weights.
self._loss = {} # type: Dict[int, float]
self._best = None # type: Optional[Vector]
def run(
self,
@@ -177,8 +182,8 @@
async with self.netwk:
# Conduct the initialization phase.
await self.initialization(config)
self.checkpointer.save_model()
self.checkpointer.checkpoint(float("inf")) # save initial weights
if self.ckptr:
self.ckptr.checkpoint(self.model, self.optim, first_call=True)
# Iteratively run training and evaluation rounds.
round_i = 0
while True:
@@ -478,6 +483,7 @@
EvaluateConfig dataclass instance wrapping data-batching
and computational effort constraints hyper-parameters.
"""
# Send evaluation requests and collect clients' replies.
self.logger.info("Initiating evaluation round %s", round_i)
clients = self._select_evaluation_round_participants()
await self._send_evaluation_instructions(clients, round_i, valid_cfg)
@@ -485,12 +491,22 @@
results = await self._collect_results(
clients, messaging.EvaluationReply, "evaluation"
)
# Compute and report aggregated evaluation metrics.
self.logger.info("Aggregating evaluation results.")
loss, metrics = self._aggregate_evaluation_results(results)
self.logger.info("Global loss is: %s", loss)
self.logger.info("Averaged loss is: %s", loss)
if metrics:
self.logger.info("Other global metrics are: %s", metrics)
self.checkpointer.checkpoint(loss)
self.logger.info(
"Other averaged scalar metrics are: %s",
{k: v for k, v in metrics.items() if isinstance(v, float)},
)
# Optionally checkpoint the model, optimizer and metrics.
if self.ckptr:
self._checkpoint_after_evaluation(metrics, results)
# Record the global loss, and update the kept "best" weights.
self._loss[round_i] = loss
if loss == min(self._loss.values()):
self._best = self.model.get_weights()
def _select_evaluation_round_participants(
self,
@@ -551,6 +567,7 @@
# Case when the client reported some metrics.
if reply.metrics:
states = reply.metrics.copy()
# Update the global metrics based on the local ones.
s_loss = states.pop("loss")
loss += s_loss["current"] # type: ignore
dvsr += s_loss["divisor"] # type: ignore
@@ -567,6 +584,50 @@
loss = loss / dvsr
return loss, metrics
def _checkpoint_after_evaluation(
self,
metrics: Dict[str, Union[float, np.ndarray]],
results: Dict[str, messaging.EvaluationReply],
) -> None:
"""Checkpoint the current model, optimizer and evaluation metrics.
This method is meant to be called at the end of an evaluation round.
Parameters
----------
metrics: dict[str, (float|np.ndarray)]
Aggregated evaluation metrics to checkpoint.
results: dict[str, EvaluationReply]
Client-wise EvaluationReply messages, based on which
`metrics` were already computed.
"""
# This method only works when a checkpointer is used.
if self.ckptr is None:
raise RuntimeError(
"`_checkpoint_after_evaluation` was called without "
"the FederatedServer having a Checkpointer."
)
# Checkpoint the model, optimizer and global evaluation metrics.
timestamp = self.ckptr.checkpoint(
model=self.model, optimizer=self.optim, metrics=metrics
)
# Checkpoint the client-wise metrics (or at least their loss).
# Use the same timestamp label as for global metrics and states.
local = MetricSet.from_config(self.metrics.get_config())
for client, reply in results.items():
if reply.metrics:
local.reset()
local.agg_states(reply.metrics)
metrics = local.get_result()
else:
metrics = {"loss": reply.loss}
self.ckptr.save_metrics(
metrics=metrics,
prefix=f"metrics_{client}",
append=bool(self._loss),
timestamp=timestamp,
)
def _keep_training(
self,
round_i: int,
@@ -589,7 +650,7 @@
self.logger.info("Maximum number of training rounds reached.")
return False
if early_stop is not None:
early_stop.update(self.checkpointer.get_loss(round_i))
early_stop.update(self._loss[round_i])
if not early_stop.keep_training:
self.logger.info("Early stopping criterion reached.")
return False
@@ -606,11 +667,16 @@
rounds: int
Number of training rounds taken until now.
"""
self.checkpointer.reset_best_weights()
self.logger.info("Recovering weights that yielded the lowest loss.")
message = messaging.StopTraining(
weights=self.model.get_weights(),
loss=min(self.checkpointer.get_loss(i) for i in range(rounds)),
weights=self._best or self.model.get_weights(),
loss=min(self._loss.values()) if self._loss else float("nan"),
rounds=rounds,
)
self.logger.info("Notifying clients that training is over.")
await self.netwk.broadcast_message(message)
if self.ckptr:
path = f"{self.ckptr.folder}/model_state_best.json"
self.logger.info("Checkpointing final weights under %s.", path)
self.model.set_weights(message.weights)
self.ckptr.save_model(self.model, timestamp="best")
@@ -327,7 +327,10 @@ class TrainingManager:
effort = constraints.get_values()
result = self.metrics.get_result()
states = self.metrics.get_states()
self.logger.info("Local evaluation metrics: %s", result)
self.logger.info(
"Local scalar evaluation metrics: %s",
{k: v for k, v in result.items() if isinstance(v, float)},
)
# Pack the result and computational effort information into a message.
self.logger.info("Packing local results to be sent to the server.")
return messaging.EvaluationReply(
@@ -68,7 +68,8 @@ def run_client(
# (5) Instantiate a FederatedClient and run it.
client = FederatedClient(
network, train, valid, folder=f"{FILEDIR}/results/{name}"
# fmt: off
network, train, valid, checkpoint=f"{FILEDIR}/results/{name}"
# Note: you may add `share_metrics=False` to prevent sending
# evaluation metrics to the server, out of privacy concerns
)
......
@@ -2,7 +2,7 @@
import argparse
import os
from typing import List
from typing import Collection
import pandas as pd
@@ -29,14 +29,14 @@ DATADIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
def get_data(
dir: str = DATADIR,
names: List[str] = NAMES,
folder: str = DATADIR,
names: Collection[str] = NAMES,
) -> None:
"""Download and process the UCI heart disease dataset.
Arguments
---------
dir: str
folder: str
Path to the folder where to write output csv files.
names: list[str]
Names of centers, the dataset from which to download,
@@ -61,8 +61,8 @@
# Binarize the target variable.
df["num"] = (df["num"] > 0).astype(int)
# Export the resulting dataset to a csv file.
os.makedirs(dir, exist_ok=True)
df.to_csv(f"{dir}/{name}.csv", index=False)
os.makedirs(folder, exist_ok=True)
df.to_csv(f"{folder}/{name}.csv", index=False)
# Code executed when the script is called directly.
@@ -70,7 +70,7 @@ if __name__ == "__main__":
# Parse commandline parameters.
parser = argparse.ArgumentParser()
parser.add_argument(
"--dir",
"--folder",
type=str,
default=DATADIR,
help="folder where to write output csv files",
@@ -84,4 +84,4 @@
)
args = parser.parse_args()
# Download and pre-process the selected dataset(s).
get_data(dir=args.dir, names=args.names)
get_data(folder=args.folder, names=args.names)
@@ -5,10 +5,9 @@ import os
from declearn.communication import NetworkServerConfig
from declearn.main import FederatedServer
from declearn.main.config import FLRunConfig, FLOptimConfig
from declearn.main.config import FLOptimConfig, FLRunConfig
from declearn.model.sklearn import SklearnSGDModel
FILEDIR = os.path.dirname(os.path.abspath(__file__))
@@ -85,7 +84,10 @@
# f1-score and roc auc (with plot-enabling fpr/tpr curves) during
# evaluation rounds.
server = FederatedServer(
model, network, optim, metrics=["binary-classif", "binary-roc"]
# fmt: off
model, network, optim,
metrics=["binary-classif", "binary-roc"],
checkpoint=f"{FILEDIR}/results/server"
)
# Here, we set up 20 rounds of training, with 30 samples per batch
# coding: utf-8
"""Unit tests for Checkpointer class."""
import json
import os
from pathlib import Path
from typing import Dict, Iterator, List, Union
from unittest import mock
import numpy as np
import pandas as pd
import pytest
from sklearn.linear_model import SGDClassifier
from declearn.main.utils import Checkpointer
from declearn.model.api import Model
from declearn.model.sklearn import SklearnSGDModel
from declearn.optimizer import Optimizer
from declearn.utils import json_load
# Fixtures and utils
@pytest.fixture(name="checkpointer")
def fixture_checkpointer(tmp_path) -> Iterator[Checkpointer]:
"""Create a checkpointer within a temp dir"""
yield Checkpointer(tmp_path, 2)
@pytest.fixture(name="model")
def fixture_model() -> SklearnSGDModel:
"""Crete a toy binary-classification model."""
model = SklearnSGDModel(SGDClassifier())
model.initialize({"n_features": 8, "classes": np.arange(2)})
return model
@pytest.fixture(name="optimizer")
def fixture_optimizer() -> Optimizer:
"""Create a toy optimizer"""
testopt = Optimizer(lrate=1.0, modules=[("momentum", {"beta": 0.95})])
return testopt
@pytest.fixture(name="metrics")
def fixture_metrics() -> Dict[str, float]:
"""Create a metrics fixture"""
return {"loss": 0.5}
def create_state_files(folder: str, type_obj: str, n_files: int) -> List[str]:
"""Create test state files in checkpointer.ckpt"""
files = [
f"{type_obj}_state_23-01-{21 + idx}_15-45-35.json"
for idx in range(n_files)
]
for name in files:
with open(os.path.join(folder, name), "w", encoding="utf-8") as file:
json.dump({"test": "state"}, file)
return files
def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str:
"""Create test cfg files in checkpointer.ckpt"""
path = os.path.join(checkpointer.folder, f"{type_obj}_config.json")
with open(path, "w", encoding="utf-8") as file:
json.dump({"test": "config"}, file)
return f"{type_obj}_config.json"
# Actual tests
class TestCheckpointer:
"""Unit tests for Checkpointer class"""
def test_init_default(self, tmp_path: str) -> None:
"""Test `Checkpointer.__init__` with `max_history=None`."""
checkpointer = Checkpointer(folder=tmp_path, max_history=None)
assert checkpointer.folder == tmp_path
assert Path(checkpointer.folder).is_dir()
assert checkpointer.max_history is None
def test_init_max_history(self, tmp_path: str) -> None:
"""Test `Checkpointer.__init__` with `max_history=2`."""
checkpointer = Checkpointer(folder=tmp_path, max_history=2)
assert checkpointer.folder == tmp_path
assert Path(checkpointer.folder).is_dir()
assert checkpointer.max_history == 2
def test_init_fails(self, tmp_path: str) -> None:
"""Test `Checkpointer.__init__` raises on negative `max_history`."""
with pytest.raises(TypeError):
Checkpointer(folder=tmp_path, max_history=-1)
def test_from_specs(self, tmp_path: str) -> None:
"""Test that `Checkpointer.from_specs` works properly.
This test is multi-part rather than unitary as the method
is merely boilerplate code refactored into a classmethod.
"""
tmp_path = str(tmp_path) # note: PosixPath
specs_list = [
tmp_path,
{"folder": tmp_path, "max_history": None},
Checkpointer(tmp_path),
]
# Iteratively test the various types of acceptable specs.
for specs in specs_list:
ckpt = Checkpointer.from_specs(specs) # type: ignore
assert isinstance(ckpt, Checkpointer)
assert ckpt.folder == tmp_path
assert ckpt.max_history is None
# Also test that the documented TypeError is raised.
with pytest.raises(TypeError):
Checkpointer.from_specs(0) # type: ignore
def test_garbage_collect(self, tmp_path: str) -> None:
"""Test `Checkpointer.garbage_collect` when collection is needed."""
# Set up a checkpointer with max_history=2 and 3 state files.
checkpointer = Checkpointer(folder=tmp_path, max_history=2)
names = sorted(create_state_files(tmp_path, "model", n_files=3))
checkpointer.garbage_collect("model_state")
# Verify that the "oldest" file was removed.
files = sorted(os.listdir(checkpointer.folder))
assert len(files) == checkpointer.max_history
assert files == names[1:] # i.e. [-max_history:]
def test_garbage_collect_no_collection(self, tmp_path: str) -> None:
"""Test `Checkpointer.garbage_collect` when collection is unneeded."""
# Set up a checkpointer with max_history=3 and 2 state files.
checkpointer = Checkpointer(folder=tmp_path, max_history=3)
names = sorted(create_state_files(tmp_path, "model", n_files=2))
checkpointer.garbage_collect("model_state")
# Verify that no files were removed.
files = sorted(os.listdir(checkpointer.folder))
assert files == names
def test_garbage_collect_infinite_history(self, tmp_path: str) -> None:
"""Test `Checkpointer.garbage_collect` when `max_history=None`."""
# Set up a checkpointer with max_history=None and 3 state files.
checkpointer = Checkpointer(folder=tmp_path, max_history=None)
names = sorted(create_state_files(tmp_path, "model", n_files=3))
checkpointer.garbage_collect("model_state")
# Verify that no files were removed.
files = sorted(os.listdir(checkpointer.folder))
assert files == names
def test_sort_matching_files(self, tmp_path: str) -> None:
"""Test `Checkpointer.sort_matching_files`."""
checkpointer = Checkpointer(folder=tmp_path)
names = sorted(create_state_files(tmp_path, "model", n_files=3))
create_state_files(tmp_path, "optimizer", n_files=2)
files = checkpointer.sort_matching_files("model_state")
assert names == files
@pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
@pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
def test_save_model(
self, tmp_path: str, model: Model, config: bool, state: bool
) -> None:
"""Test `Checkpointer.save_model` with provided parameters."""
checkpointer = Checkpointer(folder=tmp_path)
timestamp = checkpointer.save_model(model, config, state)
# Verify config save file's existence.
cfg_path = os.path.join(checkpointer.folder, "model_config.json")
if config:
assert Path(cfg_path).is_file()
else:
assert not Path(cfg_path).is_file()
# Verify weights save file's existence.
if state: # test state file save
assert isinstance(timestamp, str)
state_path = os.path.join(
checkpointer.folder, f"model_state_{timestamp}.json"
)
assert Path(state_path).is_file()
else:
assert timestamp is None
assert not checkpointer.sort_matching_files("model_state")
@pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
@pytest.mark.parametrize("config", [True, False], ids=["config", "no_cfg"])
def test_save_optimizer(
self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
) -> None:
"""Test `Checkpointer.save_optimizer` with provided parameters."""
checkpointer = Checkpointer(folder=tmp_path)
timestamp = checkpointer.save_optimizer(optimizer, config, state)
# Verify config save file's existence.
cfg_path = os.path.join(checkpointer.folder, "optimizer_config.json")
if config:
assert Path(cfg_path).is_file()
else:
assert not Path(cfg_path).is_file()
# Verify state save file's existence.
if state:
assert isinstance(timestamp, str)
state_path = os.path.join(
checkpointer.folder, f"optimizer_state_{timestamp}.json"
)
assert Path(state_path).is_file()
else:
assert timestamp is None
assert not checkpointer.sort_matching_files("optimizer_state")
def test_save_metrics(self, tmp_path: str) -> None:
"""Test that `Checkpointer.save_metrics` works as expected.
This is a multi-part test rather than unit one, to verify
that the `append` parameter and its backend work properly.
"""
# Setup for this multi-part test.
metrics = {
"foo": 42.0,
"bar": np.array([0, 1]),
} # type: Dict[str, Union[float, np.ndarray]]
checkpointer = Checkpointer(tmp_path)
csv_path = os.path.join(tmp_path, "metrics.csv")
json_path = os.path.join(tmp_path, "metrics.json")
# Case 'append=True' but the files do not exist.
checkpointer.save_metrics(metrics, append=True, timestamp="0")
assert os.path.isfile(csv_path)
assert os.path.isfile(json_path)
scalars = pd.DataFrame({"timestamp": [0], "foo": [42.0]})
assert (pd.read_csv(csv_path) == scalars).all(axis=None)
m_json = {"foo": 42.0, "bar": [0, 1]}
assert json_load(json_path) == {"0": m_json}
# Case 'append=False', overwriting existing files.
checkpointer.save_metrics(metrics, append=False, timestamp="0")
assert (pd.read_csv(csv_path) == scalars).all(axis=None)
assert json_load(json_path) == {"0": m_json}
# Case 'append=True', appending to existing files.
checkpointer.save_metrics(metrics, append=True, timestamp="1")
scalars = pd.DataFrame({"timestamp": [0, 1], "foo": [42.0, 42.0]})
m_json = {"0": m_json, "1": m_json}
assert (pd.read_csv(csv_path) == scalars).all(axis=None)
assert json_load(json_path) == m_json
@pytest.mark.parametrize("first", [True, False], ids=["first", "notfirst"])
def test_checkpoint(
self, tmp_path: str, model: Model, optimizer: Optimizer, first: bool
) -> None:
"""Test that `Checkpointer.checkpoint` works as expected."""
# Set up a checkpointer and call its checkpoint method.
checkpointer = Checkpointer(tmp_path)
metrics = {"foo": 42.0, "bar": np.array([0, 1])}
if first: # create some files that should be removed on `first_call`
create_config_file(checkpointer, "model")
timestamp = checkpointer.checkpoint(
model=model,
optimizer=optimizer,
metrics=metrics, # type: ignore
first_call=first,
)
assert isinstance(timestamp, str)
# Verify whether config files exist, as expected.
m_cfg = os.path.join(tmp_path, "model_config.json")
o_cfg = os.path.join(tmp_path, "optimizer_config.json")
if first:
assert os.path.isfile(m_cfg)
assert os.path.isfile(o_cfg)
else:
assert not os.path.isfile(m_cfg)
assert not os.path.isfile(o_cfg)
# Verify that state and metric files exist as expected.
path = os.path.join(tmp_path, f"model_state_{timestamp}.json")
assert os.path.isfile(path)
path = os.path.join(tmp_path, f"optimizer_state_{timestamp}.json")
assert os.path.isfile(path)
assert os.path.isfile(os.path.join(tmp_path, "metrics.csv"))
assert os.path.isfile(os.path.join(tmp_path, "metrics.json"))
@pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
@pytest.mark.parametrize("config", [True, False], ids=["config", "model"])
def test_load_model(
self, tmp_path: str, model: Model, config: bool, state: bool
) -> None:
"""Test `Checkpointer.load_model` with provided parameters."""
checkpointer = Checkpointer(tmp_path)
# Save the model (config + weights), then reload based on parameters.
timestamp = checkpointer.save_model(model, config=True, state=True)
with mock.patch.object(type(model), "set_weights") as p_set_weights:
loaded_model = checkpointer.load_model(
model=(None if config else model),
timestamp=(timestamp if config else None), # arbitrary swap
load_state=state,
)
# Verify that the loaded model is either the input one or similar.
if config:
assert isinstance(loaded_model, type(model))
assert loaded_model is not model
assert loaded_model.get_config() == model.get_config()
else:
assert loaded_model is model
# Verify that `set_weights` was called, with proper values.
if state:
p_set_weights.assert_called_once()
if config:
assert loaded_model.get_weights() == model.get_weights()
else:
p_set_weights.assert_not_called()
def test_load_model_fails(self, tmp_path: str, model: Model) -> None:
"""Test that `Checkpointer.load_model` raises excepted errors."""
checkpointer = Checkpointer(tmp_path)
# Case when the weights file is missing.
checkpointer.save_model(model, config=False, state=False)
with pytest.raises(FileNotFoundError):
checkpointer.load_model(model=model, load_state=True)
# Case when the config file is missing.
checkpointer.save_model(model, config=False, state=True)
with pytest.raises(FileNotFoundError):
checkpointer.load_model(model=None)
# Case when a wrong model input is provided.
with pytest.raises(TypeError):
checkpointer.load_model(model="wrong-type") # type: ignore
@pytest.mark.parametrize("state", [True, False], ids=["state", "no_state"])
@pytest.mark.parametrize("config", [True, False], ids=["config", "optim"])
def test_load_optimizer(
self, tmp_path: str, optimizer: Optimizer, config: bool, state: bool
) -> None:
"""Test `Checkpointer.load_optimizer` with provided parameters."""
checkpointer = Checkpointer(tmp_path)
# Save the optimizer (config + state), then reload based on parameters.
stamp = checkpointer.save_optimizer(optimizer, config=True, state=True)
with mock.patch.object(Optimizer, "set_state") as p_set_state:
loaded_optim = checkpointer.load_optimizer(
optimizer=(None if config else optimizer),
timestamp=(stamp if config else None), # arbitrary swap
load_state=state,
)
# Verify that the loaded optimizer is either the input one or similar.
if config:
assert isinstance(loaded_optim, Optimizer)
assert loaded_optim is not optimizer
assert loaded_optim.get_config() == optimizer.get_config()
else:
assert loaded_optim is optimizer
# Verify that `set_state` was called, with proper values.
if state:
p_set_state.assert_called_once()
if config:
assert loaded_optim.get_state() == optimizer.get_state()
else:
p_set_state.assert_not_called()
def test_load_optimizer_fails(
self, tmp_path: str, optimizer: Optimizer
) -> None:
"""Test that `Checkpointer.load_optimizer` raises excepted errors."""
checkpointer = Checkpointer(tmp_path)
# Case when the state file is missing.
checkpointer.save_optimizer(optimizer, config=False, state=False)
with pytest.raises(FileNotFoundError):
checkpointer.load_optimizer(optimizer=optimizer, load_state=True)
# Case when the config file is missing.
checkpointer.save_optimizer(optimizer, config=False, state=True)
with pytest.raises(FileNotFoundError):
checkpointer.load_optimizer(optimizer=None)
# Case when a wrong optimizer input is provided.
with pytest.raises(TypeError):
checkpointer.load_optimizer(optimizer="wrong-type") # type: ignore
def test_load_metrics(self, tmp_path: str) -> None:
"""Test that `Checkpointer.load_metrics` works properly."""
# Setup things by saving a couple of sets of metrics.
metrics = {
"foo": 42.0,
"bar": np.array([0, 1]),
} # type: Dict[str, Union[float, np.ndarray]]
checkpointer = Checkpointer(tmp_path)
time_0 = checkpointer.save_metrics(metrics, append=False)
time_1 = checkpointer.save_metrics(metrics, append=True)
# Test reloading all checkpointed metrics.
reloaded = checkpointer.load_metrics(timestamp=None)
assert isinstance(reloaded, dict)
assert reloaded.keys() == {time_0, time_1}
for scores in reloaded.values():
assert isinstance(scores, dict)
assert scores.keys() == metrics.keys()
assert scores["foo"] == metrics["foo"]
assert (scores["bar"] == metrics["bar"]).all() # type: ignore
# Test reloading only metrics from one timestamp.
reloaded = checkpointer.load_metrics(timestamp=time_0)
assert isinstance(reloaded, dict)
assert reloaded.keys() == {time_0}
def test_load_scalar_metrics(self, tmp_path: str) -> None:
"""Test that `Checkpointer.load_scalar_metrics` works properly."""
# Setup things by saving a couple of sets of metrics.
metrics = {
"foo": 42.0,
"bar": np.array([0, 1]),
} # type: Dict[str, Union[float, np.ndarray]]
checkpointer = Checkpointer(tmp_path)
time_0 = checkpointer.save_metrics(metrics, append=False)
time_1 = checkpointer.save_metrics(metrics, append=True)
expect = pd.DataFrame(
{"foo": [42.0, 42.0], "timestamp": [time_0, time_1]}
).set_index("timestamp")
# Test reloading scalar metrics.
scores = checkpointer.load_scalar_metrics()
assert isinstance(scores, pd.DataFrame)
assert scores.index.names == expect.index.names
assert scores.columns == expect.columns
assert scores.shape == expect.shape
assert (scores == expect).all(axis=None)
@@ -193,7 +193,7 @@ class DeclearnTestCase:
netwk = self.build_netwk_server()
optim = self.build_optim_config()
with tempfile.TemporaryDirectory() as folder:
server = FederatedServer(model, netwk, optim, folder=folder)
server = FederatedServer(model, netwk, optim, checkpoint=folder)
config = {
"rounds": self.rounds,
"register": {"max_clients": self.nb_clients, "timeout": 20},