From 9dd5fcc4712b72631129cf4577a4350b57c58f74 Mon Sep 17 00:00:00 2001
From: ANDREY Paul <paul.andrey@inria.fr>
Date: Mon, 13 Mar 2023 14:10:47 +0000
Subject: [PATCH] Improve `run_as_processes` to capture exceptions and outputs.

* Add exception and output capture to multiprocessed routines.
* Change the output signature of `run_as_processes` to return a bool
  flag indicating success, and a list of routine-wise output values
  or RuntimeError instances (the latter resulting either from an
  actual failure or from the process having been interrupted due to
  another one's failure).
* Add an `auto_stop` parameter that makes it possible to disable the
  default automated interruption of processes once any of them fails.
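
Illustrative usage of the revised API (a minimal sketch; the `ping`
routine below is a hypothetical placeholder, not part of declearn):

    from declearn.test_utils import run_as_processes

    def ping(name: str) -> str:
        """Toy routine standing in for an actual server/client one."""
        return f"{name}: ok"

    success, outputs = run_as_processes(
        (ping, ("server",)),
        (ping, ("client",)),
        auto_stop=True,  # default: interrupt all routines if any fails
    )
    # On success, `outputs` holds routine-wise return values, in input
    # order. Failed or interrupted slots hold RuntimeError instances.
    if not success:
        raise RuntimeError(
            "\n".join(str(e) for e in outputs if isinstance(e, RuntimeError))
        )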
---
 declearn/test_utils/_multiprocess.py | 88 ++++++++++++++++++++++++----
 examples/heart-uci/run.py            |  9 ++-
 test/communication/test_routines.py  |  6 +-
 test/functional/test_main.py         |  9 +--
 test/functional/test_regression.py   |  9 +--
 5 files changed, 96 insertions(+), 25 deletions(-)
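
Reviewer note: the sketch below (hypothetical `fail` and `hang`
routines, not part of the patch) shows how failures now surface:

    import time

    from declearn.test_utils import run_as_processes

    def fail() -> None:
        raise ValueError("boom")

    def hang() -> None:
        time.sleep(3600)

    # With `auto_stop=True` (the default), `hang` is terminated as soon
    # as `fail` exits with an error. `outputs[0]` is a RuntimeError that
    # wraps the ValueError and its traceback, while `outputs[1]` is a
    # RuntimeError flagging that the process was interrupted.
    success, outputs = run_as_processes((fail, ()), (hang, ()))
    assert not success
    assert all(isinstance(out, RuntimeError) for out in outputs)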

diff --git a/declearn/test_utils/_multiprocess.py b/declearn/test_utils/_multiprocess.py
index 05eeb251..77881d52 100644
--- a/declearn/test_utils/_multiprocess.py
+++ b/declearn/test_utils/_multiprocess.py
@@ -18,7 +18,9 @@
 """Utils to run concurrent routines parallelly using multiprocessing."""
 
 import multiprocessing as mp
-from typing import Any, Callable, List, Optional, Tuple
+import sys
+import traceback
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 
 __all__ = [
@@ -27,8 +29,9 @@ __all__ = [
 
 
 def run_as_processes(
-    *routines: Tuple[Callable[..., Any], Tuple[Any, ...]]
-) -> List[Optional[int]]:
+    *routines: Tuple[Callable[..., Any], Tuple[Any, ...]],
+    auto_stop: bool = True,
+) -> Tuple[bool, List[Union[Any, RuntimeError]]]:
     """Run coroutines concurrently within individual processes.
 
     Parameters
@@ -37,27 +40,88 @@ def run_as_processes(
         Sequence of routines that need running concurrently,
         each formatted as a 2-elements tuple containing the
         function to run, and a tuple storing its arguments.
+    auto_stop: bool, default=True
+        Whether to automatically interrupt all running routines
+        as soon as one of them fails and raises an exception.
+        This can avoid infinite runtime (e.g. when a routine
+        awaits information from one that has failed), but may
+        also prevent some exceptions from being caught, due to
+        the early stopping of routines that would have failed
+        later. Hence it may be disabled in contexts where it
+        is preferable to wait for all routines to terminate
+        independently rather than assume they are co-dependent.
 
     Returns
     -------
-    exitcodes: list[int]
-        List of exitcodes of the processes wrapping the routines.
-        If all codes are zero, then all functions ran properly.
+    success: bool
+        Whether all routines were run without raising an exception.
+    outputs: list[RuntimeError or Any]
+        List of routine-wise output values or RuntimeError exceptions
+        that either wrap an actual exception and its traceback, or
+        indicate that the process was interrupted while running.
     """
-    # Wrap routines as individual processes and run them concurrently.
-    processes = [mp.Process(target=func, args=args) for func, args in routines]
+    # Wrap routines into named processes and set up exception catching.
+    queue = mp.Queue()  # type: ignore  # Queue[Tuple[str, Union[Any, RuntimeError]]]
+    names = []  # type: List[str]
+    count = {}  # type: Dict[str, int]
+    processes = []  # type: List[mp.Process]
+    for func, args in routines:
+        name = func.__name__
+        nidx = count[name] = count.get(name, 0) + 1
+        name = f"{name}-{nidx}"
+        func = add_exception_catching(func, queue, name)
+        names.append(name)
+        processes.append(mp.Process(target=func, args=args, name=name))
+    # Run the processes concurrently.
     try:
         # Start all processes.
         for process in processes:
             process.start()
         # Regularly check for any failed process and exit if so.
         while any(process.is_alive() for process in processes):
-            if any(process.exitcode for process in processes):
+            if auto_stop and any(process.exitcode for process in processes):
                 break
-            processes[0].join(timeout=1)
+            # Wait for at most 1 second on the first alive process.
+            for process in processes:
+                if process.is_alive():
+                    process.join(timeout=1)
+                    break
     # Ensure not to leave processes running in the background.
     finally:
         for process in processes:
             process.terminate()
-    # Return processes' exitcodes.
-    return [process.exitcode for process in processes]
+    # Return success flag and re-ordered outputs and exceptions.
+    success = all(process.exitcode == 0 for process in processes)
+    dequeue = dict([queue.get_nowait() for _ in range(queue.qsize())])
+    int_err = RuntimeError("Process was interrupted while running.")
+    outputs = [dequeue.get(name, int_err) for name in names]
+    return success, outputs
+
+
+def add_exception_catching(
+    func: Callable[..., Any],
+    queue: mp.Queue,
+    name: Optional[str] = None,
+) -> Callable[..., Any]:
+    """Wrap a function to catch exceptions and put them in a Queue."""
+    if not name:
+        name = func.__name__
+
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+        """Call the wrapped function and catch exceptions or results."""
+        nonlocal name, queue
+
+        try:
+            result = func(*args, **kwargs)
+        except Exception as exc:  # pylint: disable=broad-exception-caught
+            err = RuntimeError(
+                f"Exception of type {type(exc)} occurred:\n"
+                + "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
+            )  # future: `traceback.format_exception(exc)` (py >= 3.10)
+            queue.put((name, err))
+            sys.exit(1)
+        else:
+            queue.put((name, result))
+            sys.exit(0)
+
+    return wrapped
diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py
index 5a6e4db3..4a5ac9ad 100644
--- a/examples/heart-uci/run.py
+++ b/examples/heart-uci/run.py
@@ -51,9 +51,12 @@ def run_demo(
             (run_client, (name, ca_cert)) for name in NAMES[:nb_clients]
         ]
         # Run routines in isolated processes. Raise if any failed.
-        exitcodes = run_as_processes(server, *clients)
-        if any(code != 0 for code in exitcodes):
-            raise RuntimeError("Something went wrong during the demo.")
+        success, outp = run_as_processes(server, *clients)
+        if not success:
+            raise RuntimeError(
+                "Something went wrong during the demo. Exceptions caught:\n"
+                "\n".join(str(e) for e in outp if isinstance(e, RuntimeError))
+            )
 
 
 if __name__ == "__main__":
diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py
index c1d746f2..b5dcc35a 100644
--- a/test/communication/test_routines.py
+++ b/test/communication/test_routines.py
@@ -127,9 +127,11 @@ def run_test_routines(
     routines = [_build_server_func(*args)]
     routines.extend(_build_client_funcs(*args))
     # Run the former using isolated processes.
-    exitcodes = run_as_processes(*routines)
+    success, outputs = run_as_processes(*routines)
     # Assert that all processes terminated properly.
-    assert all(code == 0 for code in exitcodes)
+    assert success, "Routines failed:\n" + "\n".join(
+        [str(exc) for exc in outputs if isinstance(exc, RuntimeError)]
+    )
 
 
 def _build_server_func(
diff --git a/test/functional/test_main.py b/test/functional/test_main.py
index f11ee0ef..5e5a59e3 100644
--- a/test/functional/test_main.py
+++ b/test/functional/test_main.py
@@ -264,10 +264,11 @@ def run_test_case(
         (test_case.run_federated_client, (f"cli_{i}",))
         for i in range(nb_clients)
     ]
-    # Run them concurrently using multiprocessing.
-    exitcodes = run_as_processes(server, *clients)
-    # Verify that all processes ended without error nor interruption.
-    assert all(code == 0 for code in exitcodes)
+    # Run them concurrently using multiprocessing. Assert none failed.
+    success, outputs = run_as_processes(server, *clients)
+    assert success, "Test case failed:\n" + "\n".join(
+        str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+    )
 
 
 @pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"])
diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index 70a52924..19ac7803 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -194,15 +194,16 @@ def test_declearn_experiment(
             client = (_client_routine, (data[0], data[1], f"client_{i}"))
             p_client.append(client)
         # Run each and every process in parallel.
-        exitcodes = run_as_processes(p_server, *p_client)
-        if not all(code == 0 for code in exitcodes):
-            raise RuntimeError("The FL experiment failed.")
+        success, outputs = run_as_processes(p_server, *p_client)
+        assert success, "The FL process failed:\n" + "\n".join(
+            str(exc) for exc in outputs if isinstance(exc, RuntimeError)
+        )
         # Assert convergence
         with open(f"{folder}/metrics.json", encoding="utf-8") as file:
             r2_dict = json.load(file)
             last_r2_dict = r2_dict.get(max(r2_dict.keys()))
             final_r2 = float(last_r2_dict.get("r2"))
-            assert final_r2 > R2_THRESHOLD
+            assert final_r2 > R2_THRESHOLD, "The FL training did not converge"
 
 
 def _server_routine(
-- 
GitLab