From 4cff0d377a246bb06fa186a7905692d0aa84980d Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 24 May 2023 13:35:04 +0200
Subject: [PATCH] Refactor 'run_as_processes' backend code.

---
 declearn/utils/_multiprocess.py | 120 ++++++++++++++++++++------------
 1 file changed, 77 insertions(+), 43 deletions(-)

diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py
index 8e4b7ca3..ee47cddd 100644
--- a/declearn/utils/_multiprocess.py
+++ b/declearn/utils/_multiprocess.py
@@ -22,7 +22,7 @@ import multiprocessing as mp
 import sys
 import traceback
 from queue import Queue
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
 
 __all__ = [
     "run_as_processes",
@@ -38,19 +38,18 @@ def run_as_processes(
     Parameters
     ----------
     *routines: tuple(function, tuple(any, ...))
-        Sequence of routines that need running concurrently,
-        each formatted as a 2-elements tuple containing the
-        function to run, and a tuple storing its arguments.
+        Sequence of routines that need running concurrently, each
+        formatted as a 2-elements tuple containing the function
+        to run, and a tuple storing its arguments.
     auto_stop: bool, default=True
-        Whether to automatically interrupt all running routines
-        as soon as one failed and raised an exception. This can
-        avoid infinite runtime (e.g. if one awaits for a failed
-        routine to send information), but may also prevent some
-        exceptions from being caught due to the early stopping
-        of routines that would have failed later. Hence it may
-        be disabled in contexts where it is interesting to wait
-        for all routines to fail rather than assume that they
-        are co-dependent.
+        Whether to automatically interrupt all running routines as
+        soon as one failed and raised an exception. This can avoid
+        infinite runtime (e.g. if one awaits for a failed routine
+        to send information), but may also prevent some exceptions
+        from being caught due to the early stopping of routines that
+        would have failed later. Hence it may be disabled in contexts
+        where it is interesting to wait for all routines to fail rather
+        than assume that they are co-dependent.
 
     Returns
     -------
@@ -64,7 +63,39 @@ def run_as_processes(
     # Wrap routines into named processes and set up exceptions catching.
     queue = (
         mp.Manager().Queue()
-    )  # type: Queue[Tuple[str, Union[Any, RuntimeError]]]
+    )  # type: Queue  # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9)
+    processes, names = prepare_routine_processes(routines, queue)
+    # Run the processes concurrently.
+    run_processes(processes, auto_stop)
+    # Return success flag and re-ordered outputs and exceptions.
+    success = all(process.exitcode == 0 for process in processes)
+    dequeue = dict([queue.get_nowait() for _ in range(queue.qsize())])
+    int_err = RuntimeError("Process was interrupted while running.")
+    outputs = [dequeue.get(name, int_err) for name in names]
+    return success, outputs
+
+
+def prepare_routine_processes(
+    routines: Iterable[Tuple[Callable[..., Any], Tuple[Any, ...]]],
+    queue: Queue,  # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9)
+) -> Tuple[List[mp.Process], List[str]]:
+    """Wrap up routines into named unstarted processes.
+
+    Parameters
+    ----------
+    routines:
+        Iterators of (function, args) tuples to wrap as processes.
+    queue:
+        Queue where to put the routines' return value or raised exception
+        (always wrapped into a RuntimeError), together with their name.
+
+    Returns
+    -------
+    processes:
+        List of `multiprocessing.Process` instances wrapping `routines`.
+    names:
+        List of names identifying the processes (used for results collection).
+    """
     names = []  # type: List[str]
     count = {}  # type: Dict[str, int]
     processes = []  # type: List[mp.Process]
@@ -75,46 +106,24 @@ def run_as_processes(
         func = add_exception_catching(func, queue, name)
         names.append(name)
         processes.append(mp.Process(target=func, args=args, name=name))
-    # Run the processes concurrently.
-    try:
-        # Start all processes.
-        for process in processes:
-            process.start()
-        # Regularly check for any failed process and exit if so.
-        while any(process.is_alive() for process in processes):
-            if auto_stop and any(process.exitcode for process in processes):
-                break
-            # Wait for at most 1 second on the first alive process.
-            for process in processes:
-                if process.is_alive():
-                    process.join(timeout=1)
-                    break
-    # Ensure not to leave processes running in the background.
-    finally:
-        for process in processes:
-            if process.is_alive():
-                process.terminate()
-    # Return success flag and re-ordered outputs and exceptions.
-    success = all(process.exitcode == 0 for process in processes)
-    dequeue = dict([queue.get_nowait() for _ in range(queue.qsize())])
-    int_err = RuntimeError("Process was interrupted while running.")
-    outputs = [dequeue.get(name, int_err) for name in names]
-    return success, outputs
+    return processes, names
 
 
 def add_exception_catching(
     func: Callable[..., Any],
-    queue: Queue,
+    queue: Queue,  # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9)
     name: str,
 ) -> Callable[..., Any]:
     """Wrap a function to catch exceptions and put them in a Queue."""
-    return functools.partial(wrapped, func=func, queue=queue, name=name)
+    return functools.partial(
+        _run_with_exception_catching, func=func, queue=queue, name=name
+    )
 
 
-def wrapped(
+def _run_with_exception_catching(
     *args: Any,
     func: Callable[..., Any],
-    queue: Queue,
+    queue: Queue,  # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9)
     name: str,
 ) -> Any:
     """Call the wrapped function and catch exceptions or results."""
@@ -130,3 +139,28 @@ def wrapped(
     else:
         queue.put((name, result))
         sys.exit(0)
+
+
+def run_processes(
+    processes: List[mp.Process],
+    auto_stop: bool,
+) -> None:
+    """Run parallel processes, optionally interrupting all if any fails."""
+    try:
+        # Start all processes.
+        for process in processes:
+            process.start()
+        # Regularly check for any failed process and exit if so.
+        while any(process.is_alive() for process in processes):
+            if auto_stop and any(process.exitcode for process in processes):
+                break
+            # Wait for at most 1 second on the first alive process.
+            for process in processes:
+                if process.is_alive():
+                    process.join(timeout=1)
+                    break
+    # Ensure not to leave processes running in the background.
+    finally:
+        for process in processes:
+            if process.is_alive():
+                process.terminate()
-- 
GitLab