Thanks for the suggestion.
One concern is that JAX keeps its concrete array type private (`jax._src.array.ArrayImpl`) so that it retains the flexibility to change its APIs, especially now that JAX arrays carry a lot of device-specific metadata (shardings, memory placement, etc.).
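Here is a quick way to see the distinction. This is a minimal sketch assuming a recent JAX release: a concrete array passes `isinstance(x, jax.Array)`, but its actual class is the private implementation type. The module path of that class has moved between JAX versions, so the snippet only prints it.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)

# jax.Array is the public, abstract type intended for isinstance checks.
print(isinstance(x, jax.Array))  # True

# The concrete class of x is a private implementation type, not jax.Array.
concrete = type(x)
print(concrete is jax.Array)  # False
print(f"{concrete.__module__}.{concrete.__qualname__}")
```

This matters here because `flax.serialization` looks up rules by the exact `type()` of each leaf. A rule registered for `jax.Array` itself would therefore never match, so a rule has to target the private class.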
On the other hand, `flax.serialization` treats an array purely as its values: no metadata such as device shardings is preserved when the bytes are deserialized later. The restored value is therefore a better fit as a NumPy array simply sitting in host CPU memory.
It isn't hard to convert the result into naive, unsharded JAX arrays: `jax.tree.map(jnp.array, pt_loaded)` is enough. But it may be better to leave it to the user to fill in the JAX array metadata.
Deserialising a single `jax.Array` with `flax.serialization` does not reconstruct a `jax.Array` correctly, even when one is present in the target pytree: a NumPy array comes back instead. See:
This is because no serialization rule is registered for `jax.Array`.
To fix it, one needs to register this rule:
Would it be possible to include this in Flax itself?