Add support for Jax / Haiku
Goal: implement HaikuModel and JaxNumpyVector.
Must have:

Do MVP:
- Create JaxNumpyVector
- Create HaikuModel, with a jax version of all abstract methods
- Create a minimal example for testing
- Deal with flattening and unflattening params, using jax.tree_util (see the sketch after this list)
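A minimal sketch of the flatten / unflatten round trip, assuming the vector side only needs an ordered mapping of arrays; the toy model and the str(idx) key scheme are illustrative, not the actual JaxNumpyVector design:

```python
import haiku as hk
import jax
import jax.numpy as jnp

# Toy model; only the params pytree structure matters here.
model = hk.without_apply_rng(hk.transform(lambda x: hk.nets.MLP([32, 1])(x)))
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))

# Flatten the nested haiku params into a list of leaves plus a treedef...
leaves, treedef = jax.tree_util.tree_flatten(params)
# ...e.g. to store them as the {name: array} coefficients of a vector object.
coefs = {str(idx): leaf for idx, leaf in enumerate(leaves)}

# Unflatten: rebuild the nested params from the stored arrays and the treedef.
rebuilt = jax.tree_util.tree_unflatten(
    treedef, [coefs[str(idx)] for idx in range(len(coefs))]
)
assert jax.tree_util.tree_structure(rebuilt) == treedef
```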
Refine MVP:
- Add clipping: use optax-inspired per-sample clipping (see the sketch after this list)
- Add jit
- Add GPU support
- Add proper randomization
- Add typing
- Use the new DataInfoField
- Properly document
- Rebase and lint
- Full test suite
- Add dependencies
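A hedged sketch of how per-sample clipping and jit could combine: jax.vmap yields per-sample gradients, each of which is clipped by global norm (optax-style) before averaging. The toy MLP and l2 loss are placeholders, not the actual HaikuModel internals:

```python
import haiku as hk
import jax
import jax.numpy as jnp
import optax

# Toy model and loss; the real HaikuModel would supply these.
model = hk.without_apply_rng(hk.transform(lambda x: hk.nets.MLP([16, 1])(x)))

def sample_loss(params, x, y):
    # Loss on a single sample: re-add a batch axis of size 1.
    pred = model.apply(params, x[None, :]).squeeze()
    return jnp.mean(optax.l2_loss(pred, y))

@jax.jit
def clipped_mean_grads(params, batch_x, batch_y, max_norm=1.0):
    # Per-sample gradients via vmap over the batch axis.
    per_sample = jax.vmap(jax.grad(sample_loss), in_axes=(None, 0, 0))(
        params, batch_x, batch_y
    )
    def clip_one(grads):
        # Clip one sample's gradients by their global norm.
        scale = jnp.minimum(1.0, max_norm / (optax.global_norm(grads) + 1e-6))
        return jax.tree_util.tree_map(lambda g: g * scale, grads)
    clipped = jax.vmap(clip_one)(per_sample)
    # Average the clipped per-sample gradients over the batch.
    return jax.tree_util.tree_map(lambda g: g.mean(axis=0), clipped)

params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))
grads = clipped_mean_grads(params, jnp.ones((4, 8)), jnp.ones((4,)))
```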
Examine edge cases:
- Deal with statefulness in haiku modules (implement LSTM with unroll_net, add RNN to test; see the sketch after this list)
- Take a closer look at multi-input models (data_type at init, inputs as list): use unpack_batch
- Look at frozen weights (don't forget to update test, see 'MLP-tune')
- Figure out why per-sample gradient calculations are so slow
- Ensure BatchNorm works
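A sketch of the statefulness question, loosely following the unroll_net pattern from the Haiku examples (shapes and sizes are illustrative): the LSTM state is threaded explicitly through hk.dynamic_unroll, while hk.transform_with_state covers implicitly stateful layers such as BatchNorm:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def unroll_net(seqs: jnp.ndarray):
    # seqs: [time, batch, features]; unroll an LSTM core over the time axis.
    core = hk.LSTM(32)
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(seqs.shape[1]))
    return hk.Linear(1)(outs), state

# transform_with_state handles modules that keep internal state
# (e.g. BatchNorm moving averages); the LSTM state itself stays explicit.
model = hk.transform_with_state(unroll_net)

rng = jax.random.PRNGKey(0)
seqs = jnp.zeros((10, 4, 8))  # time=10, batch=4, features=8
params, state = model.init(rng, seqs)
(preds, lstm_state), state = model.apply(params, state, rng, seqs)
```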
Nice to haves:

Upgrade my JAX:
- Reintroduce the State class to shorten method signatures, using @chex.dataclass (sketch below)
- Make a custom class to @jit whole methods while using instance attributes (source)
- Allow for external PRNGSequence initialization, to allow optimization via the reserve argument (source) (sketch below)
- Explore using hk.Initializer
- Add @chex.variant to testing (sketch below)
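A sketch of what a @chex.dataclass-based State could look like; the TrainState name and fields are illustrative, not an existing class:

```python
import chex

@chex.dataclass
class TrainState:
    # Bundle what the jax/haiku methods currently pass around separately,
    # so signatures shrink to something like (self, state, batch).
    params: chex.ArrayTree
    opt_state: chex.ArrayTree
    rng: chex.PRNGKey

# chex dataclasses are registered as pytrees, so a TrainState can cross
# jit / grad / vmap boundaries and be updated functionally, e.g.
# state = state.replace(params=new_params, rng=new_rng).
```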
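A sketch of external PRNGSequence handling; reserve is shown here as a method call on hk.PRNGSequence, and how (or whether) it would be exposed through HaikuModel is an open question:

```python
import haiku as hk
import jax

# The sequence could be built outside HaikuModel and passed in, so callers
# control seeding; reserve() pre-splits a batch of keys in one shot instead
# of one split per next() call.
rng_seq = hk.PRNGSequence(jax.random.PRNGKey(0))
rng_seq.reserve(128)
key = next(rng_seq)
```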
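A sketch of variant-based testing; the decorator is spelled chex.variants in current chex releases, and the test class and function here are made up:

```python
import chex
import jax.numpy as jnp

class GradStepTest(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_sum_of_squares(self):
        # self.variant re-wraps the function per variant, so the same body
        # is exercised both with and without jit.
        fn = self.variant(lambda x: jnp.sum(x ** 2))
        chex.assert_trees_all_close(fn(jnp.ones(4)), 4.0)
```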
- Enforce checks on model and loss function passed as input to HaikuModel
Refine loss:
- Accept optax losses, do the mapping on loss inputs with inspect.signature (sketch below)
- Add kwargs to _compute_loss and loss_function to cover all cases of optax loss
- Look at mixed precision
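A hypothetical wrapper showing the signature-based mapping idea; wrap_optax_loss and its positional-to-keyword binding convention are assumptions, not an existing helper:

```python
import inspect

import jax.numpy as jnp
import optax

def wrap_optax_loss(loss_fn):
    # optax losses use varying parameter names (predictions/targets,
    # logits/labels, ...); inspect the signature and bind the model outputs
    # and targets to whatever the first two parameters are called.
    names = list(inspect.signature(loss_fn).parameters)
    def loss(y_pred, y_true):
        return jnp.mean(loss_fn(**{names[0]: y_pred, names[1]: y_true}))
    return loss

mse = wrap_optax_loss(optax.l2_loss)                   # (predictions, targets)
xent = wrap_optax_loss(optax.softmax_cross_entropy)    # (logits, labels)
```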
- Refine serialization: avoid using pickling.