From 9dd5fcc4712b72631129cf4577a4350b57c58f74 Mon Sep 17 00:00:00 2001 From: ANDREY Paul <paul.andrey@inria.fr> Date: Mon, 13 Mar 2023 14:10:47 +0000 Subject: [PATCH] Improve `run_as_processes` to capture exceptions and outputs. * Add exceptions and outputs catching to multiprocessed routines. * Change the output signature of `run_as_processes` to return a bool flag indicating success, and a list of routine-wise output value or RuntimeError (that may either result from an actual failure or from the process having been interrupted due to another one's failure). * Add `auto_stop` parameter to enable disabling the default automated interruption of processes once any of them has failed. --- declearn/test_utils/_multiprocess.py | 88 ++++++++++++++++++++++++---- examples/heart-uci/run.py | 9 ++- test/communication/test_routines.py | 6 +- test/functional/test_main.py | 9 +-- test/functional/test_regression.py | 9 +-- 5 files changed, 96 insertions(+), 25 deletions(-) diff --git a/declearn/test_utils/_multiprocess.py b/declearn/test_utils/_multiprocess.py index 05eeb251..77881d52 100644 --- a/declearn/test_utils/_multiprocess.py +++ b/declearn/test_utils/_multiprocess.py @@ -18,7 +18,9 @@ """Utils to run concurrent routines parallelly using multiprocessing.""" import multiprocessing as mp -from typing import Any, Callable, List, Optional, Tuple +import sys +import traceback +from typing import Any, Callable, Dict, List, Optional, Tuple, Union __all__ = [ @@ -27,8 +29,9 @@ __all__ = [ def run_as_processes( - *routines: Tuple[Callable[..., Any], Tuple[Any, ...]] -) -> List[Optional[int]]: + *routines: Tuple[Callable[..., Any], Tuple[Any, ...]], + auto_stop: bool = True, +) -> Tuple[bool, List[Union[Any, RuntimeError]]]: """Run coroutines concurrently within individual processes. Parameters @@ -37,27 +40,88 @@ def run_as_processes( Sequence of routines that need running concurrently, each formatted as a 2-elements tuple containing the function to run, and a tuple storing its arguments. + auto_stop: bool, default=True + Whether to automatically interrupt all running routines + as soon as one failed and raised an exception. This can + avoid infinite runtime (e.g. if one awaits for a failed + routine to send information), but may also prevent some + exceptions from being caught due to the early stopping + of routines that would have failed later. Hence it may + be disabled in contexts where it is interesting to wait + for all routines to fail rather than assume that they + are co-dependent. Returns ------- - exitcodes: list[int] - List of exitcodes of the processes wrapping the routines. - If all codes are zero, then all functions ran properly. + success: bool + Whether all routines were run without raising an exception. + outputs: list[RuntimeError or Any] + List of routine-wise output value or RuntimeError exception + that either wraps an actual exception and its traceback, or + indicates that the process was interrupted while running. """ - # Wrap routines as individual processes and run them concurrently. - processes = [mp.Process(target=func, args=args) for func, args in routines] + # Wrap routines into named processes and set up exceptions catching. + queue = mp.Queue() # type: ignore # mp.Queue[Union[Any, RuntimeError]] + names = [] # type: List[str] + count = {} # type: Dict[str, int] + processes = [] # type: List[mp.Process] + for func, args in routines: + name = func.__name__ + nidx = count[name] = count.get(name, 0) + 1 + name = f"{name}-{nidx}" + func = add_exception_catching(func, queue, name) + names.append(name) + processes.append(mp.Process(target=func, args=args, name=name)) + # Run the processes concurrently. try: # Start all processes. for process in processes: process.start() # Regularly check for any failed process and exit if so. while any(process.is_alive() for process in processes): - if any(process.exitcode for process in processes): + if auto_stop and any(process.exitcode for process in processes): break - processes[0].join(timeout=1) + # Wait for at most 1 second on the first alive process. + for process in processes: + if process.is_alive(): + process.join(timeout=1) + break # Ensure not to leave processes running in the background. finally: for process in processes: process.terminate() - # Return processes' exitcodes. - return [process.exitcode for process in processes] + # Return success flag and re-ordered outputs and exceptions. + success = all(process.exitcode == 0 for process in processes) + dequeue = dict([queue.get_nowait() for _ in range(queue.qsize())]) + int_err = RuntimeError("Process was interrupted while running.") + outputs = [dequeue.get(name, int_err) for name in names] + return success, outputs + + +def add_exception_catching( + func: Callable[..., Any], + queue: mp.Queue, + name: Optional[str] = None, +) -> Callable[..., Any]: + """Wrap a function to catch exceptions and put them in a Queue.""" + if not name: + name = func.__name__ + + def wrapped(*args: Any, **kwargs: Any) -> Any: + """Call the wrapped function and catch exceptions or results.""" + nonlocal name, queue + + try: + result = func(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-exception-caught + err = RuntimeError( + f"Exception of type {type(exc)} occurred:\n" + "".join(traceback.format_exception(type(exc), exc, tb=None)) + ) # future: `traceback.format_exception(exc)` (py >=3.10) + queue.put((name, err)) + sys.exit(1) + else: + queue.put((name, result)) + sys.exit(0) + + return wrapped diff --git a/examples/heart-uci/run.py b/examples/heart-uci/run.py index 5a6e4db3..4a5ac9ad 100644 --- a/examples/heart-uci/run.py +++ b/examples/heart-uci/run.py @@ -51,9 +51,12 @@ def run_demo( (run_client, (name, ca_cert)) for name in NAMES[:nb_clients] ] # Run routines in isolated processes. Raise if any failed. - exitcodes = run_as_processes(server, *clients) - if any(code != 0 for code in exitcodes): - raise RuntimeError("Something went wrong during the demo.") + success, outp = run_as_processes(server, *clients) + if not success: + raise RuntimeError( + "Something went wrong during the demo. Exceptions caught:\n" + "\n".join(str(e) for e in outp if isinstance(e, RuntimeError)) + ) if __name__ == "__main__": diff --git a/test/communication/test_routines.py b/test/communication/test_routines.py index c1d746f2..b5dcc35a 100644 --- a/test/communication/test_routines.py +++ b/test/communication/test_routines.py @@ -127,9 +127,11 @@ def run_test_routines( routines = [_build_server_func(*args)] routines.extend(_build_client_funcs(*args)) # Run the former using isolated processes. - exitcodes = run_as_processes(*routines) + success, outputs = run_as_processes(*routines) # Assert that all processes terminated properly. - assert all(code == 0 for code in exitcodes) + assert success, "Routines failed:\n" + "\n".join( + [str(exc) for exc in outputs if isinstance(exc, RuntimeError)] + ) def _build_server_func( diff --git a/test/functional/test_main.py b/test/functional/test_main.py index f11ee0ef..5e5a59e3 100644 --- a/test/functional/test_main.py +++ b/test/functional/test_main.py @@ -264,10 +264,11 @@ def run_test_case( (test_case.run_federated_client, (f"cli_{i}",)) for i in range(nb_clients) ] - # Run them concurrently using multiprocessing. - exitcodes = run_as_processes(server, *clients) - # Verify that all processes ended without error nor interruption. - assert all(code == 0 for code in exitcodes) + # Run them concurrently using multiprocessing. Assert none failed. + success, outputs = run_as_processes(server, *clients) + assert success, "Test case failed:\n" + "\n".join( + str(exc) for exc in outputs if isinstance(exc, RuntimeError) + ) @pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"]) diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py index 70a52924..19ac7803 100644 --- a/test/functional/test_regression.py +++ b/test/functional/test_regression.py @@ -194,15 +194,16 @@ def test_declearn_experiment( client = (_client_routine, (data[0], data[1], f"client_{i}")) p_client.append(client) # Run each and every process in parallel. - exitcodes = run_as_processes(p_server, *p_client) - if not all(code == 0 for code in exitcodes): - raise RuntimeError("The FL experiment failed.") + success, outputs = run_as_processes(p_server, *p_client) + assert success, "The FL process failed:\n" + "\n".join( + str(exc) for exc in outputs if isinstance(exc, RuntimeError) + ) # Assert convergence with open(f"{folder}/metrics.json", encoding="utf-8") as file: r2_dict = json.load(file) last_r2_dict = r2_dict.get(max(r2_dict.keys())) final_r2 = float(last_r2_dict.get("r2")) - assert final_r2 > R2_THRESHOLD + assert final_r2 > R2_THRESHOLD, "The FL training did not converge" def _server_routine( -- GitLab