Add support for Jax / Haiku
Compare changes
declearn/model/haiku/utils/__init__.py
0 → 100644
+ 25
− 0
GitLab upgrade completed. Current version is 17.11.1. We now benefit from the features of the release 17.11.
Goal : implement HaikuModel and JaxNumpyVectors.
Must have
jax.tree_utils
jit
haiku
modules (Implement LSTM with unroll_net, add RNN to test)unpack_batch
Nice to haves :
JAX
:
State
class to shorten method signatures, using @chex.dataclass
@jit
whole methods while using instance attributes (source):reserve
argument (source)@chex.variant
to testingHaikuModel