diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py index ae51b5f22f24d09c0c8b4eb43907f0cffab5c725..84426f53ec46e2d197b34b2e039032be28d07277 100644 --- a/declearn/quickrun/run.py +++ b/declearn/quickrun/run.py @@ -53,7 +53,12 @@ from declearn.quickrun._config import ( from declearn.quickrun._parser import parse_data_folder from declearn.quickrun._split_data import split_data from declearn.test_utils import make_importable -from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger, run_as_processes +from declearn.utils import ( + LOGGING_LEVEL_MAJOR, + get_logger, + run_as_processes, + set_device_policy, +) __all__ = ["quickrun"] @@ -94,6 +99,7 @@ def run_server( expe_config: ExperimentConfig, ) -> None: """Routine to run a FL server, called by `run_declearn_experiment`.""" + set_device_policy(gpu=False) model = get_model(folder, model_config) checkpoint = get_checkpoint(folder, expe_config) checkpoint = os.path.join(checkpoint, "server") @@ -113,11 +119,12 @@ def run_client( paths: dict, ) -> None: """Routine to run a FL client, called by `run_declearn_experiment`.""" - # Overwrite client name based on folder name + # Overwrite client name based on folder name. network.name = name - # Make the model importable + # Make the model importable and disable GPU use. + set_device_policy(gpu=False) _ = get_model(folder, model_config) - # Add checkpointer + # Add checkpointer. checkpoint = get_checkpoint(folder, expe_config) checkpoint = os.path.join(checkpoint, name) # Set up a logger: write everything to file, but filter console outputs.