Verified Commit d02b880e, authored by ANDREY Paul

Minor formatting revisions to 'HaikuModel'.

parent defe9ba6
1 merge request: !32 Add support for Jax / Haiku
@@ -18,20 +18,14 @@
 """Model subclass to wrap Haiku models."""
 
 import functools
+import inspect
 import io
 import warnings
 from copy import deepcopy
 from random import SystemRandom
 from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
+    # fmt: off
+    Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
 )
 
 import haiku as hk
@@ -209,7 +203,7 @@ class HaikuModel(Model):
     def get_named_weights(
         self,
         trainable: bool = False,
-    ) -> Any:
+    ) -> Dict[str, Dict[str, jax.Array]]:
         """Access the weights of the haiku model as a nested dict.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
@@ -217,7 +211,6 @@ class HaikuModel(Model):
         trainable: bool, default=False
             If True, restrict the returned weights to the trainable ones,
             else return all weights.
-
         """
         assert self._treedef is not None, "uninitialized JaxModel"
         params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
@@ -231,8 +224,7 @@ class HaikuModel(Model):
         return params
 
     def _get_fixed_named_weights(self) -> Any:
-        """Access the fixed weights of the haiku model as a nested dict,
-        if any.
+        """Access the fixed weights of the model as a nested dict, if any.
 
         Return type is any due to `jax.tree_util.tree_unflatten`.
         """
@@ -344,11 +336,13 @@
         if not self._initialized:
             raise ValueError("Model needs to be initialized first")
         if isinstance(criterion, list) and isinstance(criterion[0], int):
-            self._trainable = criterion  # type: ignore
+            self._trainable = criterion
         else:
             self._trainable = []  # reset if needed
-            if isinstance(criterion, Callable):  # type: ignore
-                include_fn = criterion
+            if inspect.isfunction(criterion):
+                include_fn = (
+                    criterion
+                )  # type: Callable[[str, str, jax.Array], bool]
             elif isinstance(criterion, dict):
                 include_fn = self._build_include_fn(criterion)
             else:
@@ -359,13 +353,17 @@
             for idx, (layer, name, value) in enumerate(
                 self._traverse_params()
             ):
-                if include_fn(layer, name, value):  # type: ignore
+                if include_fn(layer, name, value):
                     self._trainable.append(idx)
 
-    def _traverse_params(self) -> Generator:
-        """Traverse the pytree of model named weight left-to-right,
-        depth-first, returning a generator"""
-        params = self.get_named_weights()
+    def _traverse_params(self) -> Iterator[Tuple[str, str, jax.Array]]:
+        """Traverse the pytree of a model's named weights.
+
+        Yield (layer_name, weight_name, weight_value) tuples from
+        traversing the pytree left-to-right, depth-first.
+        """
+        assert self._treedef is not None, "uninitialized JaxModel"
+        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
         for layer in params:
             bundle = params[layer]
             for name in bundle:
@@ -373,11 +371,13 @@
                 yield layer, name, value
 
     @staticmethod
-    def _build_include_fn(node_names: Dict[str, Dict[str, Any]]) -> Callable:
-        """Build an equality checking function, conforming to what is
-        expected at traversal"""
+    def _build_include_fn(
+        node_names: Dict[str, Dict[str, Any]],
+    ) -> Callable[[str, str, jax.Array], bool]:
+        """Build an equality-checking function for parameters' traversal."""
 
-        def include_fn(layer, name, value):  # pylint: disable=W0613
+        def include_fn(layer: str, name: str, value: jax.Array) -> bool:
+            # mandatory signature; pylint: disable=unused-argument
             if layer in list(node_names.keys()):
                 return name in list(node_names[layer].keys())
             return False
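
The two hunks above make the trainability machinery explicit: `_traverse_params` yields `(layer, name, value)` tuples depth-first, and an include function built from a criterion dict picks which leaf indices remain trainable. A hypothetical standalone sketch of that selection logic (names and values are illustrative, not the package API):

```python
from typing import Dict, Iterator, Tuple

import jax
import jax.numpy as jnp

params: Dict[str, Dict[str, jax.Array]] = {
    "linear": {"b": jnp.zeros((2,)), "w": jnp.ones((4, 2))},
    "linear_1": {"b": jnp.zeros((1,)), "w": jnp.ones((2, 1))},
}

def traverse(nested) -> Iterator[Tuple[str, str, jax.Array]]:
    # Walk the nested dict left-to-right, depth-first.
    for layer, bundle in nested.items():
        for name, value in bundle.items():
            yield layer, name, value

# Criterion dict: keep only 'linear'/'w' trainable, freeze everything else.
criterion = {"linear": {"w": None}}

def include_fn(layer: str, name: str, value: jax.Array) -> bool:
    return layer in criterion and name in criterion[layer]

trainable = [
    idx
    for idx, (layer, name, value) in enumerate(traverse(params))
    if include_fn(layer, name, value)
]
assert trainable == [1]  # index of 'linear'/'w' in traversal order
```
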
@@ -442,7 +442,7 @@ class HaikuModel(Model):
             next(self._rng_gen),
             inputs,
         )
-        # Flatten the gradients and return them in a Vector container
+        # Flatten the gradients and return them in a Vector container.
         flat_grad, _ = jax.tree_util.tree_flatten(grads)
         return JaxNumpyVector({str(k): v for k, v in enumerate(flat_grad)})
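
Only the trailing comment changes here, but the pattern it documents is that flattened gradient leaves are keyed by their stringified index before being wrapped in a `JaxNumpyVector`. A standalone sketch using a plain dict in place of that container:

```python
import jax
import jax.numpy as jnp

grads = {"linear": {"b": jnp.zeros((2,)), "w": jnp.ones((4, 2))}}
flat_grad, _ = jax.tree_util.tree_flatten(grads)
# Keys are flat leaf indices; the order is the deterministic pytree leaf order.
coefs = {str(k): v for k, v in enumerate(flat_grad)}
assert list(coefs) == ["0", "1"]
```
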
@@ -479,7 +479,7 @@ class HaikuModel(Model):
     def _clipped_grad_fn(
         self,
     ) -> Callable[
-        [hk.Params, hk.Params, jax.Array, JaxBatch, float], jax.Array
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
     ]:
         """Lazy-built jax function to compute clipped sample-wise gradients.
@@ -524,6 +524,7 @@ class HaikuModel(Model):
                 return jnp.array(data)  # pylint: disable=no-member
             raise TypeError("HaikuModel requires numpy or jax.numpy data.")
 
+        # similar code to TorchModel; pylint: disable=duplicate-code
         # Convert batched data to jax Arrays.
         inputs, y_true, s_wght = batch
         if not isinstance(inputs, (tuple, list)):
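
The added pylint directive acknowledges that this batch-unpacking code mirrors its Torch counterpart. For reference, a tiny hypothetical sketch of the conversion rule enforced here (numpy or jax arrays accepted, anything else rejected):

```python
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np

def to_jax_array(data: Union[np.ndarray, jax.Array]) -> jax.Array:
    # Accept numpy or jax arrays; reject any other input type.
    if isinstance(data, (np.ndarray, jax.Array)):
        return jnp.array(data)
    raise TypeError("HaikuModel requires numpy or jax.numpy data.")

inputs = [to_jax_array(np.ones((8, 3)))]
```
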
@@ -544,7 +545,7 @@ class HaikuModel(Model):
     def compute_batch_predictions(
         self,
         batch: Batch,
-    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]:
+    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
         inputs, y_true, s_wght = self._unpack_batch(batch)
         if y_true is None:
             raise TypeError(
@@ -558,8 +559,7 @@ class HaikuModel(Model):
             self._transformed_model.apply(params, next(self._rng_gen), *inputs)
         )
         y_true = np.asarray(y_true)  # type: ignore
-        if isinstance(s_wght, jax.Array):
-            s_wght = np.asarray(s_wght)  # type: ignore
+        s_wght = None if s_wght is None else np.asarray(s_wght)  # type: ignore
         return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
...
@@ -259,6 +259,7 @@ class TestHaikuModel(ModelTestSuite):
         criterion_type: str,
     ) -> None:
         """Check that `set_weights` behaves properly with frozen weights."""
+        # similar code to TorchModel tests; pylint: disable=duplicate-code
         # Setup a model with some frozen weights, and gather trainable ones.
         model = test_case.model
         criterion = test_case.get_trainable_criterion(criterion_type)
...