diff --git a/declearn/utils/_multiprocess.py b/declearn/utils/_multiprocess.py index ee47cdddd4352867c32cdc4bb885f7d01b2160a7..314c71638ba4819c449849e2e9b3e4de1853ea2d 100644 --- a/declearn/utils/_multiprocess.py +++ b/declearn/utils/_multiprocess.py @@ -30,7 +30,11 @@ __all__ = [ def run_as_processes( - *routines: Tuple[Callable[..., Any], Tuple[Any, ...]], + *routines: Union[ + Tuple[Callable[..., Any], Tuple[Any, ...]], + Tuple[Callable[..., Any], Dict[str, Any]], + Tuple[Callable[..., Any], Tuple[Any, ...], Dict[str, Any]], + ], auto_stop: bool = True, ) -> Tuple[bool, List[Union[Any, RuntimeError]]]: """Run coroutines concurrently within individual processes. @@ -39,8 +43,13 @@ def run_as_processes( ---------- *routines: tuple(function, tuple(any, ...)) Sequence of routines that need running concurrently, each - formatted as a 2-elements tuple containing the function - to run, and a tuple storing its arguments. + formatted as either: + - a 3-elements tuple containing the function to run, + a tuple of positional args and a dict of kwargs. + - a 2-elements tuple containing the function to run, + and a tuple storing its (positional) arguments. + - a 2-elements tuple containing the function to run, + and a dict storing its keyword arguments. auto_stop: bool, default=True Whether to automatically interrupt all running routines as soon as one failed and raised an exception. This can avoid @@ -76,7 +85,13 @@ def run_as_processes( def prepare_routine_processes( - routines: Iterable[Tuple[Callable[..., Any], Tuple[Any, ...]]], + routines: Iterable[ + Union[ + Tuple[Callable[..., Any], Tuple[Any, ...]], + Tuple[Callable[..., Any], Dict[str, Any]], + Tuple[Callable[..., Any], Tuple[Any, ...], Dict[str, Any]], + ] + ], queue: Queue, # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9) ) -> Tuple[List[mp.Process], List[str]]: """Wrap up routines into named unstarted processes. @@ -89,6 +104,11 @@ def prepare_routine_processes( Queue where to put the routines' return value or raised exception (always wrapped into a RuntimeError), together with their name. + Raises + ------ + TypeError + If the inputs do not match the expected type specifications. + Returns ------- processes: @@ -99,16 +119,80 @@ def prepare_routine_processes( names = [] # type: List[str] count = {} # type: Dict[str, int] processes = [] # type: List[mp.Process] - for func, args in routines: + for routine in routines: + func, args, kwargs = parse_routine_specification(routine) name = func.__name__ nidx = count[name] = count.get(name, 0) + 1 name = f"{name}-{nidx}" func = add_exception_catching(func, queue, name) names.append(name) - processes.append(mp.Process(target=func, args=args, name=name)) + processes.append( + mp.Process(target=func, args=args, kwargs=kwargs, name=name) + ) return processes, names +def parse_routine_specification( + routine: Union[ + Tuple[Callable[..., Any], Tuple[Any, ...]], + Tuple[Callable[..., Any], Dict[str, Any]], + Tuple[Callable[..., Any], Tuple[Any, ...], Dict[str, Any]], + ] +) -> Tuple[Callable[..., Any], Tuple[Any, ...], Dict[str, Any]]: + """Type-check and unpack a given routine specification. + + Raises + ------ + TypeError + If the inputs do not match the expected type specifications. + + Returns + ------- + func: + Callable to wrap as a process. + args: + Tuple of positional arguments to `func`. May be empty. + kwargs: + Dict of keyword arguments to `func`. May be empty. + """ + # Type-check the overall input. + if not (isinstance(routine, (tuple, list)) and (len(routine) in (2, 3))): + raise TypeError( + "Received an unproper routine specification: should be a 2- " + "or 3-element tuple." + ) + # Check that the first argument is callable. + func = routine[0] + if not callable(func): + raise TypeError( + "The first argument of a routine specification should be callable." + ) + # Case of a 2-elements tuple: may be (func, args) or (func, kwargs). + if len(routine) == 2: + if isinstance(routine[1], tuple): + args = routine[1] + kwargs = {} + elif isinstance(routine[1], dict): + args = tuple() + kwargs = routine[1] + else: + raise TypeError( + "Received an unproper routine specification: 2nd element " + f"should be a tuple or dict, not '{type(routine[1])}'." + ) + # Case of a 3-elements tuple: should be (func, args, kwargs). + else: + args = routine[1] # type: ignore # verified below + kwargs = routine[2] # type: ignore # verified below + if not (isinstance(args, tuple) and isinstance(kwargs, dict)): + raise TypeError( + "Received an unproper routine specification: 2nd and 3rd " + f"elements should be a tuple and a dict, not '{type(args)}'" + f" and '{type(kwargs)}'." + ) + return func, args, kwargs + + def add_exception_catching( func: Callable[..., Any], queue: Queue, # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9) @@ -125,10 +209,11 @@ def _run_with_exception_catching( func: Callable[..., Any], queue: Queue, # Queue[Tuple[str, Union[Any, RuntimeError]]] (py >=3.9) name: str, + **kwargs: Any, ) -> Any: """Call the wrapped function and catch exceptions or results.""" try: - result = func(*args) + result = func(*args, **kwargs) except Exception as exc: # pylint: disable=broad-exception-caught err = RuntimeError( f"Exception of type {type(exc)} occurred:\n"