Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9dd5fcc4 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Improve `run_as_processes` to capture exceptions and outputs.

* Add exceptions and outputs catching to multiprocessed routines.
* Change the output signature of `run_as_processes` to return a bool
  flag indicating success, and a list of routine-wise output value or
  RuntimeError (that may either result from an actual failure or from
  the process having been interrupted due to another one's failure).
* Add `auto_stop` parameter to enable disabling the default automated
  interruption of processes once any of them has failed.
parent 35ec6c72
No related branches found
No related tags found
1 merge request!37Improve `run_as_processes` to capture exceptions and outputs.
......@@ -18,7 +18,9 @@
"""Utils to run concurrent routines in parallel using multiprocessing."""
import multiprocessing as mp
from typing import Any, Callable, List, Optional, Tuple
import sys
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
__all__ = [
......@@ -27,8 +29,9 @@ __all__ = [
def run_as_processes(
    *routines: Tuple[Callable[..., Any], Tuple[Any, ...]],
    auto_stop: bool = True,
) -> Tuple[bool, List[Union[Any, RuntimeError]]]:
    """Run routines concurrently within individual processes.

    Parameters
    ----------
    routines: tuple(function, tuple(any, ...))
        Sequence of routines that need running concurrently,
        each formatted as a 2-elements tuple containing the
        function to run, and a tuple storing its arguments.
    auto_stop: bool, default=True
        Whether to automatically interrupt all running routines
        as soon as one failed and raised an exception. This can
        avoid infinite runtime (e.g. if one awaits for a failed
        routine to send information), but may also prevent some
        exceptions from being caught due to the early stopping
        of routines that would have failed later. Hence it may
        be disabled in contexts where it is interesting to wait
        for all routines to fail rather than assume that they
        are co-dependent.

    Returns
    -------
    success: bool
        Whether all routines were run without raising an exception.
    outputs: list[RuntimeError or Any]
        List of routine-wise output values or RuntimeError exceptions
        that either wrap an actual exception and its traceback, or
        indicate that the process was interrupted while running.
    """
    # Wrap routines into named processes and set up exceptions catching.
    queue = mp.Queue()  # type: ignore  # mp.Queue[Union[Any, RuntimeError]]
    names = []  # type: List[str]
    count = {}  # type: Dict[str, int]
    processes = []  # type: List[mp.Process]
    for func, args in routines:
        # Derive a unique name per process, even when a function is reused.
        name = func.__name__
        nidx = count[name] = count.get(name, 0) + 1
        name = f"{name}-{nidx}"
        func = add_exception_catching(func, queue, name)
        names.append(name)
        processes.append(mp.Process(target=func, args=args, name=name))
    # Run the processes concurrently.
    try:
        # Start all processes.
        for process in processes:
            process.start()
        # Regularly check for any failed process and exit early if so.
        while any(process.is_alive() for process in processes):
            if auto_stop and any(process.exitcode for process in processes):
                break
            # Wait for at most 1 second on the first alive process.
            for process in processes:
                if process.is_alive():
                    process.join(timeout=1)
                    break
    # Ensure not to leave processes running in the background.
    finally:
        for process in processes:
            process.terminate()
    # Return success flag and re-ordered outputs and exceptions.
    success = all(process.exitcode == 0 for process in processes)
    # NOTE(review): `Queue.qsize` may raise NotImplementedError on some
    # platforms (e.g. macOS) — confirm target platforms support it.
    dequeue = dict(queue.get_nowait() for _ in range(queue.qsize()))
    int_err = RuntimeError("Process was interrupted while running.")
    # Map outputs back to the input routines' order; processes that never
    # reported (interrupted) are assigned the interruption error.
    outputs = [dequeue.get(name, int_err) for name in names]
    return success, outputs
def add_exception_catching(
    func: Callable[..., Any],
    queue: mp.Queue,
    name: Optional[str] = None,
) -> Callable[..., Any]:
    """Wrap a function to catch exceptions and put them in a Queue.

    Parameters
    ----------
    func: function
        Function that needs wrapping.
    queue: multiprocessing.Queue
        Queue into which a `(name, output_or_error)` tuple is put
        once the wrapped function has returned or failed.
    name: str or None, default=None
        Identifier under which to report the outcome. If None, use
        the wrapped function's `__name__`.

    Returns
    -------
    wrapped: function
        Wrapped function, that reports its outcome through `queue`
        and exits the running process with code 0 on success or 1
        on failure.
    """
    if not name:
        name = func.__name__

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function and catch exceptions or results."""
        try:
            result = func(*args, **kwargs)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Note: the explicit `+` is required; without it the message
            # literal would be used as the `str.join` separator instead
            # of a prefix. Pass the exception's own traceback so that it
            # is actually included in the formatted report.
            err = RuntimeError(
                f"Exception of type {type(exc)} occurred:\n"
                + "".join(
                    traceback.format_exception(
                        type(exc), exc, exc.__traceback__
                    )
                )
            )  # future: `traceback.format_exception(exc)` (py >= 3.10)
            queue.put((name, err))
            sys.exit(1)
        else:
            queue.put((name, result))
            sys.exit(0)

    return wrapped
......@@ -51,9 +51,12 @@ def run_demo(
(run_client, (name, ca_cert)) for name in NAMES[:nb_clients]
]
# Run routines in isolated processes. Raise if any failed.
exitcodes = run_as_processes(server, *clients)
if any(code != 0 for code in exitcodes):
raise RuntimeError("Something went wrong during the demo.")
success, outp = run_as_processes(server, *clients)
if not success:
raise RuntimeError(
"Something went wrong during the demo. Exceptions caught:\n"
+ "\n".join(str(e) for e in outp if isinstance(e, RuntimeError))
)
if __name__ == "__main__":
......
......@@ -127,9 +127,11 @@ def run_test_routines(
routines = [_build_server_func(*args)]
routines.extend(_build_client_funcs(*args))
# Run the former using isolated processes.
exitcodes = run_as_processes(*routines)
success, outputs = run_as_processes(*routines)
# Assert that all processes terminated properly.
assert all(code == 0 for code in exitcodes)
assert success, "Routines failed:\n" + "\n".join(
[str(exc) for exc in outputs if isinstance(exc, RuntimeError)]
)
def _build_server_func(
......
......@@ -264,10 +264,11 @@ def run_test_case(
(test_case.run_federated_client, (f"cli_{i}",))
for i in range(nb_clients)
]
# Run them concurrently using multiprocessing.
exitcodes = run_as_processes(server, *clients)
# Verify that all processes ended without error nor interruption.
assert all(code == 0 for code in exitcodes)
# Run them concurrently using multiprocessing. Assert none failed.
success, outputs = run_as_processes(server, *clients)
assert success, "Test case failed:\n" + "\n".join(
str(exc) for exc in outputs if isinstance(exc, RuntimeError)
)
@pytest.mark.parametrize("strategy", ["FedAvg", "FedAvgM", "Scaffold"])
......
......@@ -194,15 +194,16 @@ def test_declearn_experiment(
client = (_client_routine, (data[0], data[1], f"client_{i}"))
p_client.append(client)
# Run each and every process in parallel.
exitcodes = run_as_processes(p_server, *p_client)
if not all(code == 0 for code in exitcodes):
raise RuntimeError("The FL experiment failed.")
success, outputs = run_as_processes(p_server, *p_client)
assert success, "The FL process failed:\n" + "\n".join(
str(exc) for exc in outputs if isinstance(exc, RuntimeError)
)
# Assert convergence
with open(f"{folder}/metrics.json", encoding="utf-8") as file:
r2_dict = json.load(file)
last_r2_dict = r2_dict.get(max(r2_dict.keys()))
final_r2 = float(last_r2_dict.get("r2"))
assert final_r2 > R2_THRESHOLD
assert final_r2 > R2_THRESHOLD, "The FL training did not converge"
def _server_routine(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment