diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 63924478dbcbf8487fd82ef1bb474856a494c113..beb5c36024032d18d8b96d721b9e34843bfd01ef 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -21,7 +21,6 @@ import functools
 import inspect
 import io
 import warnings
-from copy import deepcopy
 from random import SystemRandom
 from typing import (
     # fmt: off
@@ -109,11 +108,11 @@ class HaikuModel(Model):
         # Select the device where to place computations.
         policy = get_device_policy()
         self._device = select_device(gpu=policy.gpu, idx=policy.idx)
-        # Create model state attributes
-        self._pleaves = []  # type: List[jax.Array]
-        self._treedef = None  # type: Optional[jax.tree_util.PyTreeDef]
-        self._trainable = []  # type: List[int]
-        # Initialize the PRNG
+        # Create model state attributes.
+        self._params = {}  # type: hk.Params
+        self._pnames = []  # type: List[str]
+        self._trainable = []  # type: List[str]
+        # Initialize the PRNG.
         if seed is None:
             seed = int(SystemRandom().random() * 10e6)
         self._rng_gen = hk.PRNGSequence(seed)
@@ -144,18 +143,20 @@ class HaikuModel(Model):
         self.data_info = aggregate_data_info(
             [data_info], self.required_data_info
         )
-        # initialize.
+        # Initialize the model parameters.
         params = self._transformed_model.init(
             next(self._rng_gen),
             jnp.zeros(
                 (1, *data_info["features_shape"]), data_info["data_type"]
             ),
         )
-        params = jax.device_put(params, self._device)
-        pleaves, treedef = jax.tree_util.tree_flatten(params)
-        self._treedef = treedef
-        self._pleaves = pleaves
-        self._trainable = list(range(len(pleaves)))
+        self._params = jax.device_put(params, self._device)
+        self._pnames = [
+            f"{layer}:{weight}"
+            for layer, weights in self._params.items()
+            for weight in weights
+        ]
+        self._trainable = self._pnames.copy()
         self._initialized = True
 
     def get_config(
@@ -195,51 +196,15 @@ class HaikuModel(Model):
         self,
         trainable: bool = False,
     ) -> JaxNumpyVector:
-        params = {str(k): v for k, v in enumerate(self._pleaves)}
+        params = {
+            f"{layer}:{wname}": value
+            for layer, weights in self._params.items()
+            for wname, value in weights.items()
+        }
         if trainable:
-            params = {str(idx): params[str(idx)] for idx in self._trainable}
+            params = {k: v for k, v in params.items() if k in self._trainable}
         return JaxNumpyVector(params)
 
-    def get_named_weights(
-        self,
-        trainable: bool = False,
-    ) -> Dict[str, Dict[str, jax.Array]]:
-        """Access the weights of the haiku model as a nested dict.
-
-        Return type is any due to `jax.tree_util.tree_unflatten`.
-
-        trainable: bool, default=False
-            If True, restrict the returned weights to the trainable ones,
-            else return all weights.
-        """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        if trainable:
-            pop_idx = set(range(len(self._pleaves))) - set(self._trainable)
-            for i, (layer, name, _) in enumerate(self._traverse_params()):
-                if i in pop_idx:
-                    params[layer].pop(name)
-                if len(params[layer]) == 0:
-                    params.pop(layer)
-        return params
-
-    def _get_fixed_named_weights(self) -> Any:
-        """Access the fixed weights of the model as a nested dict, if any.
-
-        Return type is any due to `jax.tree_util.tree_unflatten`.
-        """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        if len(self._trainable) == len(self._pleaves):
-            return {}
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        pop_idx = set(self._trainable)
-        for i, (layer, name, _) in enumerate(self._traverse_params()):
-            if i in pop_idx:
-                params[layer].pop(name)
-            if len(params[layer]) == 0:
-                params.pop(layer)
-        return params
-
     def set_weights(  # type: ignore  # Vector subtype specification
         self,
         weights: JaxNumpyVector,
@@ -248,16 +213,9 @@ class HaikuModel(Model):
         if not isinstance(weights, JaxNumpyVector):
             raise TypeError("HaikuModel requires JaxNumpyVector weights.")
         self._verify_weights_compatibility(weights, trainable=trainable)
-        coefs_copy = deepcopy(weights.coefs)
-        if trainable:
-            for idx in self._trainable:
-                self._pleaves[idx] = jax.device_put(
-                    coefs_copy[str(idx)], self._device
-                )
-        else:
-            self._pleaves = [
-                jax.device_put(v, self._device) for v in coefs_copy.values()
-            ]
+        for key, val in weights.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] = val.copy()  # type: ignore
 
     def _verify_weights_compatibility(
         self,
@@ -282,10 +240,7 @@ class HaikuModel(Model):
             are present. Be verbose about the identified mismatch(es).
         """
         received = set(vector.coefs)
-        if trainable:
-            expected = {str(i) for i in self._trainable}
-        else:
-            expected = {str(i) for i in range(len(self._pleaves))}
+        expected = set(self._trainable if trainable else self._pnames)
         raise_on_stringsets_mismatch(
             received, expected, context="model weights"
         )
@@ -295,7 +250,7 @@ class HaikuModel(Model):
         criterion: Union[
             Callable[[str, str, jax.Array], bool],
             Dict[str, Dict[str, Any]],
-            List[int],
+            List[str],
         ],
     ) -> None:
         """Sets the index of trainable weights.
@@ -335,7 +290,11 @@ class HaikuModel(Model):
         """
        if not self._initialized:
             raise ValueError("Model needs to be initialized first")
-        if isinstance(criterion, list) and isinstance(criterion[0], int):
+        if (
+            isinstance(criterion, list)
+            and all(isinstance(c, str) for c in criterion)
+            and all(c in self._pnames for c in criterion)
+        ):
             self._trainable = criterion
         else:
             self._trainable = []  # reset if needed
@@ -347,14 +306,12 @@ class HaikuModel(Model):
                 include_fn = self._build_include_fn(criterion)
             else:
                 raise TypeError(
-                    "The provided criterion does not conform"
-                    "to the expected format and or type"
+                    "The provided criterion does not conform "
+                    "to the expected format and/or type."
                 )
-            for idx, (layer, name, value) in enumerate(
-                self._traverse_params()
-            ):
+            for layer, name, value in self._traverse_params():
                 if include_fn(layer, name, value):
-                    self._trainable.append(idx)
+                    self._trainable.append(f"{layer}:{name}")
 
     def _traverse_params(self) -> Iterator[Tuple[str, str, jax.Array]]:
         """Traverse the pytree of a model's named weights.
@@ -362,13 +319,11 @@ class HaikuModel(Model):
         Yield (layer_name, weight_name, weight_value) tuples from traversing
         the pytree left-to-right, depth-first.
         """
-        assert self._treedef is not None, "uninitialized JaxModel"
-        params = jax.tree_util.tree_unflatten(self._treedef, self._pleaves)
-        for layer in params:
-            bundle = params[layer]
-            for name in bundle:
-                value = bundle[name]
-                yield layer, name, value
+        yield from (
+            (layer, weight, value)
+            for layer, weights in self._params.items()
+            for weight, value in weights.items()
+        )
 
     @staticmethod
     def _build_include_fn(
@@ -384,6 +339,57 @@ class HaikuModel(Model):
 
         return include_fn
 
+    def get_weight_names(
+        self,
+        trainable: bool = False,
+    ) -> List[str]:
+        """Return the list of names of the model's weights.
+
+        Parameters
+        ----------
+        trainable: bool
+            Whether to return only the names of trainable weights,
+            rather than including both trainable and frozen ones.
+
+        Returns
+        -------
+        names:
+            Ordered list of model weights' names.
+        """
+        return self._trainable.copy() if trainable else self._pnames.copy()
+
+    def compute_batch_gradients(
+        self,
+        batch: Batch,
+        max_norm: Optional[float] = None,
+    ) -> JaxNumpyVector:
+        # Unpack input batch and prepare model parameters.
+        inputs = self._unpack_batch(batch)
+        train_params, fixed_params = hk.data_structures.partition(
+            predicate=lambda l, w, _: f"{l}:{w}" in self._trainable,
+            structure=self._params,
+        )
+        rng = next(self._rng_gen)
+        # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
+        if max_norm:
+            grads = self._clipped_grad_fn(
+                train_params, fixed_params, rng, inputs, max_norm
+            )
+            grads = [value.mean(0) for value in grads]
+        else:
+            grads = jax.tree_util.tree_leaves(
+                self._grad_fn(train_params, fixed_params, rng, inputs)
+            )
+        # Return the gradients, flattened into a JaxNumpyVector container.
+        return JaxNumpyVector(dict(zip(self._trainable, grads)))
+
+    @functools.cached_property
+    def _grad_fn(
+        self,
+    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+        """Lazy-built jax function to compute batch-averaged gradients."""
+        return jax.jit(jax.grad(self._forward))
+
     def _forward(
         self,
         train_params: hk.Params,
@@ -399,7 +405,7 @@ class HaikuModel(Model):
             The model parameters, as a nested dict of jax arrays.
         rng: jax.Array
             A jax pseudo-random number generator (PRNG) key.
-        batch: (jax.Array, jax.Array, optional[jax.Array])
+        batch: (list[jax.Array], jax.Array, optional[jax.Array])
             Batch of jax-converted inputs, comprising (in that order)
             input data, ground-truth labels and optional sample weights.
@@ -408,7 +414,6 @@ class HaikuModel(Model):
         loss: jax.Array
             The mean loss over the input data provided.
         """
-        # FUTURE: add support for lists of inputs
         inputs, y_true, s_wght = batch
         params = hk.data_structures.merge(train_params, fixed_params)
         y_pred = self._transformed_model.apply(params, rng, *inputs)
@@ -417,64 +422,6 @@ class HaikuModel(Model):
             s_loss = s_loss * s_wght
         return jnp.mean(s_loss)
 
-    def compute_batch_gradients(
-        self,
-        batch: Batch,
-        max_norm: Optional[float] = None,
-    ) -> JaxNumpyVector:
-        if max_norm:
-            return self._compute_clipped_gradients(batch, max_norm)
-        return self._compute_batch_gradients(batch)
-
-    def _compute_batch_gradients(
-        self,
-        batch: Batch,
-    ) -> JaxNumpyVector:
-        """Compute and return batch-averaged gradients of trainable weights."""
-        # Unpack input batch and unflatten model parameters.
-        inputs = self._unpack_batch(batch)
-        train_params = self.get_named_weights(trainable=True)
-        fixed_params = self._get_fixed_named_weights()
-        # Run the forward and backward passes to compute gradients.
-        grads = self._grad_fn(
-            train_params,
-            fixed_params,
-            next(self._rng_gen),
-            inputs,
-        )
-        # Flatten the gradients and return them in a Vector container.
-        flat_grad, _ = jax.tree_util.tree_flatten(grads)
-        return JaxNumpyVector({str(k): v for k, v in enumerate(flat_grad)})
-
-    @functools.cached_property
-    def _grad_fn(
-        self,
-    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
-        """Lazy-built jax function to compute batch-averaged gradients."""
-        return jax.jit(jax.grad(self._forward))
-
-    def _compute_clipped_gradients(
-        self,
-        batch: Batch,
-        max_norm: float,
-    ) -> JaxNumpyVector:
-        """Compute and return sample-wise clipped, batch-averaged gradients."""
-        # Unpack input batch and unflatten model parameters.
-        inputs = self._unpack_batch(batch)
-        train_params = self.get_named_weights(trainable=True)
-        fixed_params = self._get_fixed_named_weights()
-        # Get flattened, per-sample, clipped gradients and aggregate them.
-        clipped_grads = self._clipped_grad_fn(
-            train_params,
-            fixed_params,
-            next(self._rng_gen),
-            inputs,
-            max_norm,
-        )
-        grads = [g.sum(0) for g in clipped_grads]
-        # Return them in a Vector container.
-        return JaxNumpyVector({str(k): v for k, v in enumerate(grads)})
-
     @functools.cached_property
     def _clipped_grad_fn(
         self,
@@ -484,7 +431,8 @@ class HaikuModel(Model):
         """Lazy-built jax function to compute clipped sample-wise gradients.
 
         Note : The vmap in_axis parameters work thank to the jax feature of
-        applying optional parameters to pytrees."""
+        applying optional parameters to pytrees.
+        """
 
         def clipped_grad_fn(
             train_params: hk.Params,
@@ -497,10 +445,7 @@ class HaikuModel(Model):
             inputs, y_true, s_wght = batch
             batch = (inputs, y_true, None)
             grads = jax.grad(self._forward)(
-                train_params,
-                fixed_params,
-                rng,
-                batch,
+                train_params, fixed_params, rng, batch
             )
             grads_flat = [
                 grad / jnp.maximum(jnp.linalg.norm(grad) / max_norm, 1.0)
@@ -539,8 +484,9 @@ class HaikuModel(Model):
         if not isinstance(updates, JaxNumpyVector):
             raise TypeError("HaikuModel requires JaxNumpyVector updates.")
         self._verify_weights_compatibility(updates, trainable=True)
-        for key, upd in updates.coefs.items():
-            self._pleaves[int(key)] += upd
+        for key, val in updates.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] += val  # type: ignore
 
     def compute_batch_predictions(
         self,
@@ -554,13 +500,14 @@ class HaikuModel(Model):
                 "correct the inputs, or override this method to support "
                 "creating labels from the base inputs."
             )
-        params = self.get_named_weights()
-        y_pred = np.asarray(
-            self._transformed_model.apply(params, next(self._rng_gen), *inputs)
+        y_pred = self._transformed_model.apply(
+            self._params, next(self._rng_gen), *inputs
+        )
+        return (
+            np.asarray(y_true),
+            np.asarray(y_pred),
+            None if s_wght is None else np.asarray(s_wght),
         )
-        y_true = np.asarray(y_true)  # type: ignore
-        s_wght = None if s_wght is None else np.asarray(s_wght)  # type: ignore
-        return y_true, y_pred, s_wght  # type: ignore
 
     def loss_function(
         self,
@@ -582,4 +529,4 @@ class HaikuModel(Model):
         # When needed, re-create the model to force moving it to the device.
         if self._device is not device:
             self._device = device
-            self._pleaves = jax.device_put(self._pleaves, self._device)
+            self._params = jax.device_put(self._params, self._device)
diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py
index 8083dcda507280d5e7110490e7a07fe0f327a387..a2b8a2c6bf0dcdfca20dbaee05ad63c0379222b5 100644
--- a/test/model/test_haiku.py
+++ b/test/model/test_haiku.py
@@ -18,7 +18,7 @@
 """Unit tests for HaikuModel."""
 
 import sys
-from typing import Any, List, Literal, Tuple
+from typing import Any, Callable, Dict, List, Literal, Union
 
 import numpy as np
 import pytest
@@ -176,7 +176,8 @@ class HaikuTestCase(ModelTestCase):
             }
         )
         if self.kind == "MLP-tune":
-            model.set_trainable_weights([0, 1, 2])
+            names = model.get_weight_names()
+            model.set_trainable_weights([names[i] for i in range(3)])
         return model
 
     def assert_correct_device(
@@ -190,17 +191,24 @@ class HaikuTestCase(ModelTestCase):
             for arr in vector.coefs.values()
         )
 
-    def get_trainable_criterion(self, c_type: str):
+    def get_trainable_criterion(
+        self,
+        c_type: str,
+    ) -> Union[
+        List[str],
+        Dict[str, Dict[str, Any]],
+        Callable[[str, str, jax.Array], bool],
+    ]:
         "Build different weight freezing criteria"
-        if c_type == "indexes":
-            crit = [2, 3]
+        if c_type == "names":
+            names = self.model.get_weight_names()
+            return [names[2], names[3]]
         if c_type == "pytree":
-            params = self.model.get_named_weights()
-            params.pop(list(params.keys())[1])
-            crit = params
+            params = getattr(self.model, "_params")
+            return {k: v for i, (k, v) in enumerate(params.items()) if i != 1}
         if c_type == "predicate":
-            crit = lambda m, n, p: n != "b"  # pylint: disable=C3001
-        return crit
+            return lambda m, n, p: n != "b"
+        raise KeyError(f"Invalid 'c_type' parameter: {c_type}.")
 
 
 @pytest.fixture(name="test_case")
@@ -230,7 +238,7 @@ class TestHaikuModel(ModelTestSuite):
         super().test_serialization(test_case)
 
     @pytest.mark.parametrize(
-        "criterion_type", ["indexes", "pytree", "predicate"]
+        "criterion_type", ["names", "pytree", "predicate"]
     )
     def test_get_frozen_weights(
         self,
@@ -244,14 +252,14 @@ class TestHaikuModel(ModelTestSuite):
         w_all = model.get_weights()
         w_trn = model.get_weights(trainable=True)
         assert set(w_trn.coefs).issubset(w_all.coefs)  # check on keys
-        n_params = len(model._pleaves)  # pylint: disable=protected-access
-        n_trainable = len(model._trainable)  # pylint: disable=protected-access
+        n_params = len(model.get_weight_names())
+        n_trainable = len(model.get_weight_names(trainable=True))
         assert n_trainable < n_params
         assert len(w_trn.coefs) == n_trainable
         assert len(w_all.coefs) == n_params
 
     @pytest.mark.parametrize(
-        "criterion_type", ["indexes", "pytree", "predicate"]
+        "criterion_type", ["names", "pytree", "predicate"]
     )
     def test_set_frozen_weights(
         self,
@@ -281,7 +289,7 @@ class TestHaikuModel(ModelTestSuite):
         policy = model.device_policy
         assert policy.gpu == (test_case.device == "gpu")
         assert policy.idx == 0
-        params = getattr(model, "_pleaves")
+        params = jax.tree_util.tree_leaves(getattr(model, "_params"))
         device = f"{test_case.device}:0"
         for arr in params:
             assert f"{arr.device().platform}:{arr.device().id}" == device
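Usage note (illustrative sketch, not part of the patch): after this change, HaikuModel weights are addressed by "{layer}:{weight}" name keys instead of flat list indices, through get_weight_names(), set_trainable_weights() and get_weights() as shown in the diff above. The snippet below assumes `model` is an already-initialized HaikuModel instance; the example layer name and the "w"/"b" weight names follow Haiku's usual naming and are placeholders, not guaranteed values.

    # List all weight names, then freeze everything except the first layer.
    names = model.get_weight_names()            # e.g. "mlp/~/linear_0:w"
    first_layer = names[0].rsplit(":", 1)[0]
    keep = [name for name in names if name.startswith(first_layer + ":")]
    model.set_trainable_weights(keep)

    # Trainable-only views are keyed by the same "{layer}:{weight}" names.
    assert model.get_weight_names(trainable=True) == keep
    weights = model.get_weights(trainable=True)  # JaxNumpyVector keyed by name

    # A predicate criterion also works: keep only kernel ("w") weights trainable.
    model.set_trainable_weights(lambda layer, name, value: name == "w")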