serialization.from_state_dict does not restore to jax.Arrays #3999

Open
PhilipVinc opened this issue Jun 15, 2024 · 1 comment

@PhilipVinc
Contributor

PhilipVinc commented Jun 15, 2024

Deserialising a pytree of jax.Arrays with flax.serialization does not restore the leaves as jax.Arrays, even when the target pytree contains jax.Arrays.

For example:

import jax
import jax.numpy as jnp
import flax
from flax import serialization

pt = {'a':jnp.ones((3,4))}

bdata = serialization.to_bytes(pt)
pt_loaded = serialization.from_bytes(pt, bdata)

print("Standard Deserialization  output type:", type(pt_loaded['a']))
# Output: Standard Deserialization  output type: <class 'numpy.ndarray'>

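This is because flax.serialization has no registered rule for jax.Arrays. For an unregistered leaf type, from_state_dict returns the restored value unchanged, and that value is the NumPy array produced by the msgpack decoder.

To fix it, one needs to register a rule for the concrete array type: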

serialization.register_serialization_state(
  type(pt['a']),
  lambda x: x,
  lambda x, sd: jax.numpy.asarray(sd, dtype=x.dtype),
  override=True
)
bdata = serialization.to_bytes(pt)
pt_loaded = serialization.from_bytes(pt, bdata)

print("type Deserialization      output type:", type(pt_loaded['a']))
# Output: type Deserialization      output type: <class 'jaxlib.xla_extension.ArrayImpl'>

Would it be possible to include this rule in flax itself?

@IvyZX
Collaborator

IvyZX commented Jun 21, 2024

Thanks for the suggestion.
One concern is that JAX keeps its concrete array type private (jax._src.array.ArrayImpl) so that it retains the flexibility to change its APIs, especially since JAX arrays now carry a lot of device-specific metadata (shardings, memory placement, etc.).
On the other hand, flax.serialization treats an array purely as its values: no metadata such as device shardings is preserved when the bytes are deserialized later. A NumPy array sitting in host CPU memory is therefore a better fit for the restored value.
It isn't hard to convert the result into a plain, unsharded JAX array: jax.tree.map(jnp.array, pt_loaded) is enough. But it might be better to leave it to the user to fill in the JAX array metadata.
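Here is a minimal sketch of both options. The sketch does not depend on flax: a NumPy pytree stands in for the pt_loaded that from_bytes would return, and the names target, restored, naive and placed are illustrative. The second option re-attaches placement by taking each target leaf's sharding. This is one way a user might fill in the metadata, not a flax API.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Target pytree whose leaves are jax.Arrays on the default device.
target = {'a': jnp.ones((3, 4)), 'b': jnp.arange(5, dtype=jnp.int32)}

# Stand-in for serialization.from_bytes(target, data): host NumPy arrays
# with values only, no device or sharding metadata.
restored = {'a': np.ones((3, 4), dtype=np.float32),
            'b': np.arange(5, dtype=np.int32)}

# Option 1: the naive conversion from the comment above. Leaves become
# unsharded jax.Arrays on the default device.
naive = jax.tree.map(jnp.array, restored)

# Option 2: the user supplies placement by copying each target leaf's
# sharding onto the corresponding restored leaf.
placed = jax.tree.map(lambda t, x: jax.device_put(x, t.sharding), target, restored)
```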
