Add support for Jax / Haiku

BIGAUD Nathan requested to merge haiku-jax into develop

Goal: implement HaikuModel and JaxNumpyVector.

Must have

  • Do MVP
    • Create JaxNumpyVector
    • Create HaikuModel, with a JAX version of all abstract methods
    • Create a minimal example for testing
    • Deal with flattening and unflattening params, using jax.tree_util (see the flatten/unflatten sketch after this list)
  • Refine MVP
    • Add clipping: use optax-inspired per-sample clipping (see the vmap sketch after this list)
    • Add jit
    • Add GPU support
    • Add proper randomization
    • Add typing
    • Use the new DataInfoField
    • Properly document
    • Rebase and lint
    • Full test suite
    • Add dependencies
  • Examine edge cases
    • Deal with statefulness in Haiku modules (implement an LSTM with unroll_net, add an RNN to the tests; see the transform_with_state sketch after this list)
    • Take a closer look at multi-input models (data_type at init, inputs as a list): use unpack_batch
    • Look at frozen weights (don't forget to update the test; see 'MLP-tune')
    • Figure out why per-sample gradient calculations are so slow.
    • Ensure BatchNorm works
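
A minimal sketch of the flatten/unflatten round-trip mentioned above, using `jax.tree_util`; the MLP and shapes here are illustrative, not the actual HaikuModel wiring:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def net_fn(x: jnp.ndarray) -> jnp.ndarray:
    return hk.nets.MLP([32, 1])(x)

transformed = hk.transform(net_fn)
params = transformed.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Flatten the nested params mapping into a flat list of arrays plus a treedef.
leaves, treedef = jax.tree_util.tree_flatten(params)

# ... operate on the flat arrays (e.g. wrap them in a JaxNumpyVector) ...

# Rebuild the original nested structure from the (possibly updated) leaves.
params = jax.tree_util.tree_unflatten(treedef, leaves)
```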
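
For the optax-inspired per-sample clipping, a hedged sketch of the usual `jax.vmap` pattern; the loss below is a placeholder, not the model's actual loss:

```python
import jax
import jax.numpy as jnp

MAX_NORM = 1.0  # illustrative clipping threshold

def loss_fn(params, x, y):
    # placeholder scalar loss for a single sample
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)

def clip_grads(grads):
    # clip one sample's gradient pytree to a maximum global L2 norm
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
    scale = jnp.minimum(1.0, MAX_NORM / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

def per_sample_clipped_grads(params, xs, ys):
    # one gradient pytree per sample, each clipped before any aggregation
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    return jax.vmap(clip_grads)(grads)
```

For what it's worth, `jax.vmap(jax.grad(...))` is generally the efficient route to per-sample gradients (versus a Python loop), which may be relevant to the slowness question above.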
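
On statefulness and BatchNorm: the standard Haiku route is `hk.transform_with_state`, whose `apply()` threads module state explicitly. The network below is illustrative:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def net_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    x = hk.Linear(16)(x)
    # BatchNorm keeps moving statistics in Haiku "state", not in params
    x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)(
        x, is_training=is_training
    )
    return hk.Linear(1)(x)

transformed = hk.transform_with_state(net_fn)
params, state = transformed.init(jax.random.PRNGKey(0), jnp.ones((2, 8)), True)
# apply() returns outputs *and* the updated state, which must be carried along
out, state = transformed.apply(
    params, state, jax.random.PRNGKey(1), jnp.ones((2, 8)), True
)
```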

Nice to haves:

  • Upgrade my JAX:
    • Reintroduce the State class to shorten method signatures, using @chex.dataclass (see the sketch after this list)
    • Make a custom class to @jit whole methods while using instance attributes (source)
    • Allow external PRNGSequence initialization, to enable optimization via reserve (source; see the sketch after this list)
    • Explore using hk.Initializer
    • Add @chex.variant to testing
  • Enforce checks on the model and loss function passed as inputs to HaikuModel
  • Refine loss:
    • Accept optax losses, mapping the loss inputs via inspect.signature
    • Add kwargs to _compute_loss and loss_function to cover all cases of optax loss
    • Look at mixed precision
  • Refine serialization: avoid pickling.
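
A hedged sketch of the `@chex.dataclass` State idea mentioned above; the field names are illustrative, not a final design:

```python
import chex
import haiku as hk
import jax

@chex.dataclass
class TrainState:
    params: hk.Params  # model parameters
    state: hk.State    # mutable module state (e.g. BatchNorm statistics)
    rng: jax.Array     # current PRNG key

def next_rng(train_state: TrainState) -> TrainState:
    # chex dataclasses are immutable pytrees: use .replace() for updates,
    # and pass the whole object through jit-ted methods as one argument
    rng, _ = jax.random.split(train_state.rng)
    return train_state.replace(rng=rng)
```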
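
And a minimal sketch of the external `hk.PRNGSequence` optimization via `reserve`, which pre-splits a batch of keys once instead of calling `jax.random.split` on every draw:

```python
import haiku as hk
import jax

rng_seq = hk.PRNGSequence(jax.random.PRNGKey(42))  # externally initialized
rng_seq.reserve(100)  # pre-split 100 keys in a single operation

for _ in range(100):
    key = next(rng_seq)  # cheap: pops one of the reserved keys
    # ... use `key` for dropout, parameter init, noise, etc. ...
```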