diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
index 5442e59abf8d1bf46258fc4226c40941b15b78c7..69f3e34f7c21483901a0acdc8c4a822331126ae4 100644
--- a/declearn/model/haiku/_model.py
+++ b/declearn/model/haiku/_model.py
@@ -55,7 +55,7 @@ __all__ = [

 # alias for unpacked Batch structures, converted to jax arrays
 # FUTURE: add support for lists of inputs
-JaxBatch = Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]
+JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]]


 @register_type(name="HaikuModel", group="Model")
@@ -334,7 +334,7 @@ class HaikuModel(Model):
             must be a function taking in the name of the module (e.g. layer
             name), the element name (e.g. parameter name) and the
             corresponding data and returning a boolean. See
-            [the haiku doc](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.partition)  # noqa
+            [the haiku doc](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.partition)
             for details. If a list of integers, should represent the index
             of trainable parameters in the parameter tree leaves. If a
             dict, should be formatted as a pytree.
@@ -410,7 +410,7 @@ class HaikuModel(Model):
         # FUTURE: add support for lists of inputs
         inputs, y_true, s_wght = batch
         params = hk.data_structures.merge(train_params, fixed_params)
-        y_pred = self._transformed_model.apply(params, rng, inputs)
+        y_pred = self._transformed_model.apply(params, rng, *inputs)
         s_loss = self._loss_fn(y_pred, y_true)  # type: ignore
         if s_wght is not None:
             s_loss = s_loss * s_wght
@@ -480,7 +480,10 @@ class HaikuModel(Model):
     ) -> Callable[
         [hk.Params, hk.Params, jax.Array, JaxBatch, float], jax.Array
     ]:
-        """Lazy-built jax function to compute clipped sample-wise gradients."""
+        """Lazy-built jax function to compute clipped sample-wise gradients.
+
+        Note: the vmap `in_axes` parameters work thanks to the jax feature
+        of applying optional parameters to pytrees."""

         def clipped_grad_fn(
             train_params: hk.Params,
@@ -514,13 +517,6 @@ class HaikuModel(Model):
         """Unpack and enforce jnp.array conversion to an input data batch."""

         def convert(data: Any) -> Optional[jax.Array]:
-            if isinstance(data, (list, tuple)):
-                if len(data) == 1:
-                    data = data[0]
-                else:
-                    raise TypeError(
-                        "HaikuModel does not support multi-arrays inputs."
-                    )
             if (data is None) or isinstance(data, jax.Array):
                 return data
             if isinstance(data, np.ndarray):
@@ -529,7 +525,9 @@

         # Convert batched data to jax Arrays.
         inputs, y_true, s_wght = batch
-        output = [convert(inputs), convert(y_true), convert(s_wght)]
+        if not isinstance(inputs, (tuple, list)):
+            inputs = [inputs]
+        output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)]
         return output  # type: ignore

     def apply_updates(  # type: ignore  # Vector subtype specification
@@ -556,7 +554,7 @@ class HaikuModel(Model):
         )
         params = self.get_named_weights()
         y_pred = np.asarray(
-            self._transformed_model.apply(params, next(self._rng_gen), inputs)
+            self._transformed_model.apply(params, next(self._rng_gen), *inputs)
         )
         y_true = np.asarray(y_true)  # type: ignore
         if isinstance(s_wght, jax.Array):
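For context, the sketch below illustrates the behaviour these changes enable: batch inputs are normalized to a list of jax arrays, so both single- and multi-input haiku models can be served through the same `apply(params, rng, *inputs)` splat. This is a minimal standalone sketch rather than code from the diff; the `unpack_batch` helper is a simplified stand-in for the patched `HaikuModel` internals, and the two-input `model_fn` is purely hypothetical.

```python
from typing import Any, List, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

# Mirrors the updated alias: inputs are now a list of arrays.
JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]]


def unpack_batch(batch: Tuple[Any, Any, Any]) -> JaxBatch:
    """Convert an (inputs, y_true, s_wght) batch to jax arrays.

    Single-array inputs are wrapped into a one-element list, so that
    model inputs can always be splatted with `*inputs` at call time.
    """

    def convert(data: Any) -> Optional[jax.Array]:
        if (data is None) or isinstance(data, jax.Array):
            return data
        if isinstance(data, np.ndarray):
            return jnp.asarray(data)
        raise TypeError("unsupported batch element type")

    inputs, y_true, s_wght = batch
    if not isinstance(inputs, (tuple, list)):
        inputs = [inputs]
    return [list(map(convert, inputs)), convert(y_true), convert(s_wght)]  # type: ignore


def model_fn(x_num: jax.Array, x_cat: jax.Array) -> jax.Array:
    """Hypothetical two-input model: concatenate inputs, apply a Linear layer."""
    data = jnp.concatenate([x_num, x_cat], axis=-1)
    return hk.Linear(1)(data)


transformed = hk.transform(model_fn)
rng = jax.random.PRNGKey(0)
batch = (
    [np.ones((8, 3), dtype="float32"), np.ones((8, 2), dtype="float32")],
    np.zeros((8, 1), dtype="float32"),
    None,
)
inputs, y_true, s_wght = unpack_batch(batch)
params = transformed.init(rng, *inputs)
# The splat makes the call site identical for one or several input arrays.
y_pred = transformed.apply(params, rng, *inputs)
print(y_pred.shape)  # (8, 1)
```

Normalizing inputs to a list keeps the forward call identical regardless of how many input arrays the model takes, which is what allows the former "multi-arrays inputs" TypeError to be dropped from `convert`.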