From f398148d6c29cf544a3c562fea7531eb339e8599 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 31 Mar 2023 16:16:58 +0200
Subject: [PATCH] Disable GPU use in quickrun mode.

---
 declearn/quickrun/run.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/declearn/quickrun/run.py b/declearn/quickrun/run.py
index ae51b5f2..84426f53 100644
--- a/declearn/quickrun/run.py
+++ b/declearn/quickrun/run.py
@@ -53,7 +53,12 @@ from declearn.quickrun._config import (
 from declearn.quickrun._parser import parse_data_folder
 from declearn.quickrun._split_data import split_data
 from declearn.test_utils import make_importable
-from declearn.utils import LOGGING_LEVEL_MAJOR, get_logger, run_as_processes
+from declearn.utils import (
+    LOGGING_LEVEL_MAJOR,
+    get_logger,
+    run_as_processes,
+    set_device_policy,
+)
 
 __all__ = ["quickrun"]
 
@@ -94,6 +99,7 @@ def run_server(
     expe_config: ExperimentConfig,
 ) -> None:
     """Routine to run a FL server, called by `run_declearn_experiment`."""
+    set_device_policy(gpu=False)
     model = get_model(folder, model_config)
     checkpoint = get_checkpoint(folder, expe_config)
     checkpoint = os.path.join(checkpoint, "server")
@@ -113,11 +119,12 @@ def run_client(
     paths: dict,
 ) -> None:
     """Routine to run a FL client, called by `run_declearn_experiment`."""
-    # Overwrite client name based on folder name
+    # Overwrite client name based on folder name.
     network.name = name
-    # Make the model importable
+    # Make the model importable and disable GPU use.
+    set_device_policy(gpu=False)
     _ = get_model(folder, model_config)
-    # Add checkpointer
+    # Add checkpointer.
     checkpoint = get_checkpoint(folder, expe_config)
     checkpoint = os.path.join(checkpoint, name)
     # Set up a logger: write everything to file, but filter console outputs.
-- 
GitLab