From d02b880ed7c59aab0ffda0519f1cef2440103da5 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 May 2023 10:20:55 +0200
Subject: [PATCH] Minor formatting revisions to 'HaikuModel'.

---
 declearn/model/haiku/_model.py | 60 +++++++++++++++++-----------------
 test/model/test_haiku.py       |  1 +
 2 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index ea3a64cf..63924478 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -18,20 +18,14 @@
 """Model subclass to wrap Haiku models."""
 
 import functools
+import inspect
 import io
 import warnings
 from copy import deepcopy
 from random import SystemRandom
 from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
+    # fmt: off
+    Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
 )
 
 import haiku as hk
@@ -209,7 +203,7 @@ class HaikuModel(Model):
     def get_named_weights(
         self,
         trainable: bool = False,
-    ) -> Any:
+    ) -> Dict[str, Dict[str, jax.Array]]:
         """Access the weights of the haiku model as a nested dict.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
@@ -217,7 +211,6 @@ class HaikuModel(Model):
         trainable: bool, default=False
             If True, restrict the returned weights to the trainable
            ones, else return all weights.
-
         """
         assert self._treedef is not None, "uninitialized JaxModel"
         params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
@@ -231,8 +224,7 @@ class HaikuModel(Model):
         return params
 
     def _get_fixed_named_weights(self) -> Any:
-        """Access the fixed weights of the haiku model as a nested dict,
-        if any.
+        """Access the fixed weights of the model as a nested dict, if any.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
         """
@@ -344,11 +336,13 @@ class HaikuModel(Model):
         if not self._initialized:
             raise ValueError("Model needs to be initialized first")
         if isinstance(criterion, list) and isinstance(criterion[0], int):
-            self._trainable = criterion  # type: ignore
+            self._trainable = criterion
         else:
             self._trainable = []  # reset if needed
-            if isinstance(criterion, Callable):  # type: ignore
-                include_fn = criterion
+            if inspect.isfunction(criterion):
+                include_fn = (
+                    criterion
+                )  # type: Callable[[str, str, jax.Array], bool]
             elif isinstance(criterion, dict):
                 include_fn = self._build_include_fn(criterion)
             else:
@@ -359,13 +353,17 @@ class HaikuModel(Model):
             for idx, (layer, name, value) in enumerate(
                 self._traverse_params()
             ):
-                if include_fn(layer, name, value):  # type: ignore
+                if include_fn(layer, name, value):
                     self._trainable.append(idx)
 
-    def _traverse_params(self) -> Generator:
-        """Traverse the pytree of model named weight left-to-right,
-        depth-first, returning a generator"""
-        params = self.get_named_weights()
+    def _traverse_params(self) -> Iterator[Tuple[str, str, jax.Array]]:
+        """Traverse the pytree of a model's named weights.
+
+        Yield (layer_name, weight_name, weight_value) tuples from
+        traversing the pytree left-to-right, depth-first.
+        """
+        assert self._treedef is not None, "uninitialized JaxModel"
+        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
         for layer in params:
             bundle = params[layer]
             for name in bundle:
@@ -373,11 +371,13 @@ class HaikuModel(Model):
                 yield layer, name, value
 
     @staticmethod
-    def _build_include_fn(node_names: Dict[str, Dict[str, Any]]) -> Callable:
-        """Build an equality checking function, conforming to what is
-        expected at traversal"""
+    def _build_include_fn(
+        node_names: Dict[str, Dict[str, Any]],
+    ) -> Callable[[str, str, jax.Array], bool]:
+        """Build an equality-checking function for parameters' traversal."""
 
-        def include_fn(layer, name, value):  # pylint: disable=W0613
+        def include_fn(layer: str, name: str, value: jax.Array) -> bool:
+            # mandatory signature; pylint: disable=unused-argument
             if layer in list(node_names.keys()):
                 return name in list(node_names[layer].keys())
             return False
@@ -442,7 +442,7 @@ class HaikuModel(Model):
             next(self._rng_gen),
             inputs,
         )
-        # Flatten the gradients and return them in a Vector container
+        # Flatten the gradients and return them in a Vector container.
         flat_grad, _ = jax.tree_util.tree_flatten(grads)
         return JaxNumpyVector({str(k): v for k, v in enumerate(flat_grad)})
 
@@ -479,7 +479,7 @@ class HaikuModel(Model):
     def _clipped_grad_fn(
         self,
     ) -> Callable[
-        [hk.Params, hk.Params, jax.Array, JaxBatch, float], jax.Array
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
     ]:
         """Lazy-built jax function to compute clipped sample-wise gradients.
 
@@ -524,6 +524,7 @@ class HaikuModel(Model):
                 return jnp.array(data)  # pylint: disable=no-member
             raise TypeError("HaikuModel requires numpy or jax.numpy data.")
 
+        # similar code to TorchModel; pylint: disable=duplicate-code
         # Convert batched data to jax Arrays.
         inputs, y_true, s_wght = batch
         if not isinstance(inputs, (tuple, list)):
@@ -544,7 +545,7 @@ class HaikuModel(Model):
     def compute_batch_predictions(
         self,
         batch: Batch,
-    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]:
+    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
         inputs, y_true, s_wght = self._unpack_batch(batch)
         if y_true is None:
             raise TypeError(
@@ -558,8 +559,7 @@ class HaikuModel(Model):
             self._transformed_model.apply(params, next(self._rng_gen), *inputs)
         )
         y_true = np.asarray(y_true)  # type: ignore
-        if isinstance(s_wght, jax.Array):
-            s_wght = np.asarray(s_wght)  # type: ignore
+        s_wght = None if s_wght is None else np.asarray(s_wght)  # type: ignore
         return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
index f9231ded..8083dcda 100644
--- a/test/model/test_haiku.py
+++ b/test/model/test_haiku.py
@@ -259,6 +259,7 @@ class TestHaikuModel(ModelTestSuite):
         criterion_type: str,
     ) -> None:
         """Check that `set_weights` behaves properly with frozen weights."""
+        # similar code to TorchModel tests; pylint: disable=duplicate-code
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
         criterion = test_case.get_trainable_criterion(criterion_type)
-- 
GitLab
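Editor's note: the hunks above reshape `_traverse_params` and `_build_include_fn` around a shared `(layer, name, value)` traversal signature. The sketch below is a minimal, self-contained illustration of that selection pattern, assuming a toy `{layer: {name: array}}` pytree; the helper names (`traverse_params`, `build_include_fn`) and data are hypothetical, not declearn API.

```python
from typing import Any, Callable, Dict, Iterator, Tuple

import jax
import jax.numpy as jnp

# Toy parameter pytree, shaped like haiku's {layer: {name: array}} dicts.
toy_params = {
    "linear_1": {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))},
    "linear_2": {"w": jnp.ones((2, 1)), "b": jnp.zeros((1,))},
}


def traverse_params(
    params: Dict[str, Dict[str, jax.Array]],
) -> Iterator[Tuple[str, str, jax.Array]]:
    """Yield (layer, name, value) tuples, left-to-right, depth-first."""
    for layer, bundle in params.items():
        for name, value in bundle.items():
            yield layer, name, value


def build_include_fn(
    node_names: Dict[str, Dict[str, Any]],
) -> Callable[[str, str, jax.Array], bool]:
    """Build a predicate selecting the weights named in a nested dict."""

    def include_fn(layer: str, name: str, value: jax.Array) -> bool:
        # `value` is unused, but kept so that dict-based and function-based
        # criteria share the same traversal signature.
        return layer in node_names and name in node_names[layer]

    return include_fn


# Gather the flat indices of matching weights, in the spirit of what
# `set_trainable_weights` does when passed a dict criterion.
include_fn = build_include_fn({"linear_1": {"w": None}})
trainable = [
    idx
    for idx, (layer, name, value) in enumerate(traverse_params(toy_params))
    if include_fn(layer, name, value)
]
print(trainable)  # -> [0] (only linear_1/w is selected)
```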
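Relatedly, the patched `_traverse_params` now rebuilds the nested dict from stored leaves via `jax.tree_util.tree_unflatten`, the same mechanism `get_named_weights` uses, while the gradients code ships leaves as a flat `{str(index): leaf}` mapping. A small round-trip sketch, under the same toy-pytree assumption as above, of the flatten/unflatten contract that makes this safe:

```python
import jax
import jax.numpy as jnp

toy_params = {
    "linear_1": {"b": jnp.zeros((2,)), "w": jnp.ones((4, 2))},
}

# Flatten: leaves come out in a deterministic order (dict keys sorted),
# which is what allows indexing them as a flat {str(i): leaf} mapping,
# as done when building a JaxNumpyVector in the gradients code above.
leaves, treedef = jax.tree_util.tree_flatten(toy_params)
vector = {str(idx): leaf for idx, leaf in enumerate(leaves)}

# Unflatten: the stored treedef rebuilds the exact nested structure,
# mirroring the `tree_unflatten(self._treedef, self._pleaves)` calls.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert jax.tree_util.tree_structure(restored) == treedef
assert list(restored) == ["linear_1"]
```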