From f398148d6c29cf544a3c562fea7531eb339e8599 Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 31 Mar 2023 16:16:58 +0200 Subject: [PATCH] Disable GPU use in quickrun mode. --- declearn/quickrun/run.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py index ae51b5f2..84426f53 100644 --- a/declearn/quickrun/run.py +++ b/declearn/quickrun/run.py @@ -53,7 +53,12 @@ from declearn.quickrun._config import ( from declearn.quickrun._parser import parse_data_folder from declearn.quickrun._split_data import split_data from declearn.test_utils import make_importable -from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger, run_as_processes +from declearn.utils import ( + LOGGING_LEVEL_MAJOR, + get_logger, + run_as_processes, + set_device_policy, +) __all__ = ["quickrun"] @@ -94,6 +99,7 @@ def run_server( expe_config: ExperimentConfig, ) -> None: """Routine to run a FL server, called by `run_declearn_experiment`.""" + set_device_policy(gpu=False) model = get_model(folder, model_config) checkpoint = get_checkpoint(folder, expe_config) checkpoint = os.path.join(checkpoint, "server") @@ -113,11 +119,12 @@ def run_client( paths: dict, ) -> None: """Routine to run a FL client, called by `run_declearn_experiment`.""" - # Overwrite client name based on folder name + # Overwrite client name based on folder name. network.name = name - # Make the model importable + # Make the model importable and disable GPU use. + set_device_policy(gpu=False) _ = get_model(folder, model_config) - # Add checkpointer + # Add checkpointer. checkpoint = get_checkpoint(folder, expe_config) checkpoint = os.path.join(checkpoint, name) # Set up a logger: write everything to file, but filter console outputs. -- GitLab