Hi,

In my project, I have multiple instances of modules that look approximately like this:
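
    from typing import Any

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class TreeEncoder(nn.Module):
        leaf_encoders: Any  # pytree of nn.Modules

        @nn.compact
        def __call__(self, data: Any):  # data is a pytree of jax.Array
            # One encoder per leaf; jax.tree_util.* replaces the deprecated jax.tree_map / jax.tree_leaves.
            encoded = jax.tree_util.tree_map(lambda d, enc: enc(d), data, self.leaf_encoders)
            return jnp.stack(jax.tree_util.tree_leaves(encoded), axis=-1)
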
So essentially, this module has one encoder for every leaf of the input pytree and uses them to obtain a vector encoding of the entire tree. Usage could look like this:
The beauty of JAX is that a pytree can have an arbitrary structure; it does not have to be a dict. I make heavy use of this and sometimes define my own dataclasses.
But here comes my problem: if I use a dictionary, Flax sometimes converts it into a `FrozenDict`, and then the call to `jax.tree_map` fails with an error.
It took me a while to understand when these conversions happen, but I am now fairly certain that Flax behaves as follows:

- The inputs to `__call__` are never converted (`data` is always a `dict`).
- `leaf_encoders` is converted if and only if the `TreeEncoder` instance is created inside an `@nn.compact` call.
How do I deal with this? I cannot call `self.leaf_encoders.unfreeze()` because it might not be a `FrozenDict` (or even a dictionary). Is there some way to disable the `FrozenDict` conversion in general? Or is it possible to make `FrozenDict`s and `dict`s compatible as arguments to `jax.tree_map`?
Thanks a lot in advance!
Best,
Tim
TimSchneider42 changed the title from "Best practice of dealing with sporadic FrozenDicts conversion?" to "Best practice of dealing with sporadic FrozenDict conversions?" on Jun 13, 2024.
Hmm, are you using the latest Flax? The latest Flax should have `flax.config.flax_return_frozendict` set to `False` and return normal dicts by default.
Otherwise, you can do an explicit check like `hasattr(x, "unfreeze")` and call `x.unfreeze()` only when it is true. But please let me know if you are using the latest Flax and still run into this problem.
I have to double-check which Flax version I am using, but I installed it from PyPI around two months ago, so it should be fairly recent.
To be clear, my problem is not that modules return `FrozenDict`s, but that fields of modules sometimes get converted to `FrozenDict`s. I think I will go with your suggestion for now, but an option to turn this behavior off entirely would be nice in the future.