diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py index 5214e7bf556d80aac0cf08e98e04694a9ece4482..933f355542ea299798caa812636d32c36b6a5e2d 100644 --- a/declearn/utils/_multiprocess.py +++ b/declearn/utils/_multiprocess.py @@ -17,9 +17,11 @@ """Utils to run concurrent routines parallelly using multiprocessing.""" +import functools import multiprocessing as mp import sys import traceback +from multiprocessing.queues import Queue from typing import Any, Callable, Dict, List, Optional, Tuple, Union __all__ = [ @@ -60,7 +62,7 @@ def run_as_processes( indicates that the process was interrupted while running. """ # Wrap routines into named processes and set up exceptions catching. - queue = mp.Queue() # type: ignore # mp.Queue[Union[Any, RuntimeError]] + queue = mp.Manager().Queue() # type: ignore names = [] # type: List[str] count = {} # type: Dict[str, int] processes = [] # type: List[mp.Process] @@ -68,7 +70,7 @@ def run_as_processes( name = func.__name__ nidx = count[name] = count.get(name, 0) + 1 name = f"{name}-{nidx}" - func = add_exception_catching(func, queue, name) + func = add_exception_catching(func, queue, name) # type: ignore names.append(name) processes.append(mp.Process(target=func, args=args, name=name)) # Run the processes concurrently. @@ -100,28 +102,31 @@ def run_as_processes( def add_exception_catching( func: Callable[..., Any], - queue: mp.Queue, + queue: Queue, name: Optional[str] = None, ) -> Callable[..., Any]: """Wrap a function to catch exceptions and put them in a Queue.""" if not name: name = func.__name__ + return functools.partial(wrapped, func=func, queue=queue, name=name) - def wrapped(*args: Any, **kwargs: Any) -> Any: - """Call the wrapped function and catch exceptions or results.""" - nonlocal name, queue - try: - result = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-exception-caught - err = RuntimeError( - f"Exception of type {type(exc)} occurred:\n" - "".join(traceback.format_exception(type(exc), exc, tb=None)) - ) # future: `traceback.format_exception(exc)` (py >=3.10) - queue.put((name, err)) - sys.exit(1) - else: - queue.put((name, result)) - sys.exit(0) - - return wrapped +def wrapped( + *args: Any, + func: Callable[..., Any], + queue: Queue, + name: str, +) -> Any: + """Call the wrapped function and catch exceptions or results.""" + try: + result = func(*args) + except Exception as exc: # pylint: disable=broad-exception-caught + err = RuntimeError( + f"Exception of type {type(exc)} occurred:\n" + "".join(traceback.format_exception(type(exc), exc, tb=None)) + ) # future: `traceback.format_exception(exc)` (py >=3.10) + queue.put((name, err)) + sys.exit(1) + else: + queue.put((name, result)) + sys.exit(0)