Mentions légales du service

Skip to content
Snippets Groups Projects
Verified Commit f398148d authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Disable GPU use in quickrun mode.

parent 17d11a02
No related branches found
No related tags found
1 merge request!41Quickrun mode
...@@ -53,7 +53,12 @@ from declearn.quickrun._config import ( ...@@ -53,7 +53,12 @@ from declearn.quickrun._config import (
from declearn.quickrun._parser import parse_data_folder from declearn.quickrun._parser import parse_data_folder
from declearn.quickrun._split_data import split_data from declearn.quickrun._split_data import split_data
from declearn.test_utils import make_importable from declearn.test_utils import make_importable
from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger, run_as_processes from declearn.utils import (
LOGGING_LEVEL_MAJOR,
get_logger,
run_as_processes,
set_device_policy,
)
__all__ = ["quickrun"] __all__ = ["quickrun"]
...@@ -94,6 +99,7 @@ def run_server( ...@@ -94,6 +99,7 @@ def run_server(
expe_config: ExperimentConfig, expe_config: ExperimentConfig,
) -> None: ) -> None:
"""Routine to run a FL server, called by `run_declearn_experiment`.""" """Routine to run a FL server, called by `run_declearn_experiment`."""
set_device_policy(gpu=False)
model = get_model(folder, model_config) model = get_model(folder, model_config)
checkpoint = get_checkpoint(folder, expe_config) checkpoint = get_checkpoint(folder, expe_config)
checkpoint = os.path.join(checkpoint, "server") checkpoint = os.path.join(checkpoint, "server")
...@@ -113,11 +119,12 @@ def run_client( ...@@ -113,11 +119,12 @@ def run_client(
paths: dict, paths: dict,
) -> None: ) -> None:
"""Routine to run a FL client, called by `run_declearn_experiment`.""" """Routine to run a FL client, called by `run_declearn_experiment`."""
# Overwrite client name based on folder name # Overwrite client name based on folder name.
network.name = name network.name = name
# Make the model importable # Make the model importable and disable GPU use.
set_device_policy(gpu=False)
_ = get_model(folder, model_config) _ = get_model(folder, model_config)
# Add checkpointer # Add checkpointer.
checkpoint = get_checkpoint(folder, expe_config) checkpoint = get_checkpoint(folder, expe_config)
checkpoint = os.path.join(checkpoint, name) checkpoint = os.path.join(checkpoint, name)
# Set up a logger: write everything to file, but filter console outputs. # Set up a logger: write everything to file, but filter console outputs.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment