v0.8.5

@cgarciae cgarciae released this 26 Jun 09:27
· 27 commits to main since this release

What's Changed

  • v0.8.5 by @cgarciae in #3941
  • [nnx] improve vmap axis size detection by @cgarciae in #3947
  • Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
  • [nnx] fix nnx_basics dependencies by @cgarciae in #3942
  • Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
  • updated rng guide by @chiamp in #3912
  • upgraded haiku guide to include NNX by @chiamp in #3923
  • parameterized NNX transforms tests by @chiamp in #3906
  • Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
  • fix HEAD by @chiamp in #3960
  • Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
  • Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
  • Adding Welford metric. by @copybara-service in #3959
  • Modify Welford metric to return mean value. by @copybara-service in #3970
  • [nnx] make State generic by @cgarciae in #3964
  • updated NNX nn docstrings by @chiamp in #3972
  • make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
  • updated nnx.module docstrings by @chiamp in #3966
  • updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
  • updated nnx.graph docstrings by @chiamp in #3958
  • Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
  • Fixes for batch norm docs by @jkarwowski in #3982
  • fix deprecation warning by @chiamp in #3981
  • updated NNX rnglib docstring by @chiamp in #3980
  • updated nnx.training by @chiamp in #3975
  • updated nnx.variables docstrings by @chiamp in #3986
  • [nnx] vectorize vmap split counts by @cgarciae in #3989
  • added wrt option to nnx.Optimizer by @chiamp in #3983
  • Added nnx.graph.iter_children by @chiamp in #3991
  • [nnx] fix vmap by @copybara-service in #3995
  • Fix head pytest breakage by @IvyZX in #4006
  • Helper function for loading params from a linen module by @copybara-service in #4012
  • Port gemma/layers to NNX by @copybara-service in #4013
  • [nnx] fix grad by @cgarciae in #4007
  • [nnx] add PathContains Filter by @cgarciae in #4011
  • Support Python 3.9 by @copybara-service in #4018
  • Port gemma/modules to NNX by @copybara-service in #4014
  • Internal change to fix current head CI by @copybara-service in #4017
  • Unpin the Orbax pip version. by @copybara-service in #4024
  • Fix Gemma test to unbreak head by @IvyZX in #4025
  • Fix pickling of exceptions by @sanderland in #4002
  • Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
  • CI: add test run against oldest supported jax version by @jakevdp in #3996
  • Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029
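
Note on the Welford metric (#3959, #3970): the entries above do not describe the metric's interface, so the following is only a minimal plain-Python sketch of Welford's online algorithm, the technique the metric is named after. The class name `WelfordStats` and its methods are hypothetical and are not the Flax API. It shows how a running mean and variance can be updated one value at a time without storing past values, and `compute` returns the mean to mirror the behavior named in #3970.

```python
import math
from dataclasses import dataclass


@dataclass
class WelfordStats:
    """Running mean/variance via Welford's online algorithm.

    Hypothetical illustration only; this is not the Flax implementation.
    """

    count: int = 0
    mean: float = 0.0
    m2: float = 0.0  # running sum of squared deviations from the mean

    def update(self, value: float) -> None:
        # Fold one new observation into the running statistics.
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        # Uses the deviation from both the old and the updated mean.
        self.m2 += delta * (value - self.mean)

    def variance(self) -> float:
        # Population variance; undefined (NaN) before any update.
        return self.m2 / self.count if self.count > 0 else math.nan

    def compute(self) -> float:
        # Assumption: the reported value is the running mean.
        return self.mean


stats = WelfordStats()
for x in [1.0, 2.0, 3.0, 4.0]:
    stats.update(x)
```

Because each update touches only three scalars, the same pattern suits streaming metrics where batches arrive one at a time.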

New Contributors

Full Changelog: v0.8.4...v0.8.5