From d02b880ed7c59aab0ffda0519f1cef2440103da5 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Fri, 5 May 2023 10:20:55 +0200
Subject: [PATCH] Minor formatting revisions to 'HaikuModel'.

---
 declearn/model/haiku/_model.py | 60 +++++++++++++++++-----------------
 test/model/test_haiku.py       |  1 +
 2 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index ea3a64cf..63924478 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -18,20 +18,14 @@
 """Model subclass to wrap Haiku models."""
 
 import functools
+import inspect
 import io
 import warnings
 from copy import deepcopy
 from random import SystemRandom
 from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
+    # fmt: off
+    Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
 )
 
 import haiku as hk
@@ -209,7 +203,7 @@ class HaikuModel(Model):
     def get_named_weights(
         self,
         trainable: bool = False,
-    ) -> Any:
+    ) -> Dict[str, Dict[str, jax.Array]]:
         """Access the weights of the haiku model as a nested dict.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
@@ -217,7 +211,6 @@ class HaikuModel(Model):
         trainable: bool, default=False
             If True, restrict the returned weights to the trainable ones,
             else return all weights.
-
         """
         assert self._treedef is not None, "uninitialized JaxModel"
         params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
@@ -231,8 +224,7 @@ class HaikuModel(Model):
         return params
 
     def _get_fixed_named_weights(self) -> Any:
-        """Access the fixed weights of the haiku model as a nested dict,
-        if any.
+        """Access the fixed weights of the model as a nested dict, if any.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
         """
@@ -344,11 +336,13 @@ class HaikuModel(Model):
         if not self._initialized:
             raise ValueError("Model needs to be initialized first")
         if isinstance(criterion, list) and isinstance(criterion[0], int):
-            self._trainable = criterion  # type: ignore
+            self._trainable = criterion
         else:
             self._trainable = []  # reset if needed
-            if isinstance(criterion, Callable):  # type: ignore
-                include_fn = criterion
+            if inspect.isfunction(criterion):
+                include_fn = (
+                    criterion
+                )  # type: Callable[[str, str, jax.Array], bool]
             elif isinstance(criterion, dict):
                 include_fn = self._build_include_fn(criterion)
             else:
@@ -359,13 +353,17 @@ class HaikuModel(Model):
             for idx, (layer, name, value) in enumerate(
                 self._traverse_params()
             ):
-                if include_fn(layer, name, value):  # type: ignore
+                if include_fn(layer, name, value):
                     self._trainable.append(idx)
 
-    def _traverse_params(self) -> Generator:
-        """Traverse the pytree of model named weight left-to-right,
-        depth-first, returning a generator"""
-        params = self.get_named_weights()
+    def _traverse_params(self) -> Iterator[Tuple[str, str, jax.Array]]:
+        """Traverse the pytree of a model's named weights.
+
+        Yield (layer_name, weight_name, weight_value) tuples from
+        traversing the pytree left-to-right, depth-first.
+        """
+        assert self._treedef is not None, "uninitialized JaxModel"
+        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
         for layer in params:
             bundle = params[layer]
             for name in bundle:
@@ -373,11 +371,13 @@ class HaikuModel(Model):
                 yield layer, name, value
 
     @staticmethod
-    def _build_include_fn(node_names: Dict[str, Dict[str, Any]]) -> Callable:
-        """Build an equality checking function, conforming to what is
-        expected at traversal"""
+    def _build_include_fn(
+        node_names: Dict[str, Dict[str, Any]],
+    ) -> Callable[[str, str, jax.Array], bool]:
+        """Build an equality-checking function for parameters' traversal."""
 
-        def include_fn(layer, name, value):  # pylint: disable=W0613
+        def include_fn(layer: str, name: str, value: jax.Array) -> bool:
+            # mandatory signature; pylint: disable=unused-argument
             if layer in list(node_names.keys()):
                 return name in list(node_names[layer].keys())
             return False
@@ -442,7 +442,7 @@ class HaikuModel(Model):
             next(self._rng_gen),
             inputs,
         )
-        # Flatten the gradients and return them in a Vector container
+        # Flatten the gradients and return them in a Vector container.
         flat_grad, _ = jax.tree_util.tree_flatten(grads)
         return JaxNumpyVector({str(k): v for k, v in enumerate(flat_grad)})
 
@@ -479,7 +479,7 @@ class HaikuModel(Model):
     def _clipped_grad_fn(
         self,
     ) -> Callable[
-        [hk.Params, hk.Params, jax.Array, JaxBatch, float], jax.Array
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
     ]:
         """Lazy-built jax function to compute clipped sample-wise gradients.
 
@@ -524,6 +524,7 @@ class HaikuModel(Model):
                 return jnp.array(data)  # pylint: disable=no-member
             raise TypeError("HaikuModel requires numpy or jax.numpy data.")
 
+        # similar code to TorchModel; pylint: disable=duplicate-code
         # Convert batched data to jax Arrays.
         inputs, y_true, s_wght = batch
         if not isinstance(inputs, (tuple, list)):
@@ -544,7 +545,7 @@ class HaikuModel(Model):
     def compute_batch_predictions(
         self,
         batch: Batch,
-    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]:
+    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
         inputs, y_true, s_wght = self._unpack_batch(batch)
         if y_true is None:
             raise TypeError(
@@ -558,8 +559,7 @@ class HaikuModel(Model):
             self._transformed_model.apply(params, next(self._rng_gen), *inputs)
         )
         y_true = np.asarray(y_true)  # type: ignore
-        if isinstance(s_wght, jax.Array):
-            s_wght = np.asarray(s_wght)  # type: ignore
+        s_wght = None if s_wght is None else np.asarray(s_wght)  # type: ignore
         return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
index f9231ded..8083dcda 100644
--- a/test/model/test_haiku.py
+++ b/test/model/test_haiku.py
@@ -259,6 +259,7 @@ class TestHaikuModel(ModelTestSuite):
         criterion_type: str,
     ) -> None:
         """Check that `set_weights` behaves properly with frozen weights."""
+        # similar code to TorchModel tests; pylint: disable=duplicate-code
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
         criterion = test_case.get_trainable_criterion(criterion_type)
-- 
GitLab