Best practice of dealing with sporadic FrozenDict conversions? #3994

Open
TimSchneider42 opened this issue Jun 13, 2024 · 2 comments

@TimSchneider42

Hi,

in my project, I have multiple instances of modules that look approximately like this:

from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp

class TreeEncoder(nn.Module):
    leaf_encoders: Any  # pytree of nn.Modules

    @nn.compact
    def __call__(self, data: Any):  # data is a pytree of jax.Array
        encoded = jax.tree_util.tree_map(lambda d, enc: enc(d), data, self.leaf_encoders)
        return jnp.stack(jax.tree_util.tree_leaves(encoded), axis=-1)

So essentially, this module has one encoder for every leaf of the input pytree and uses them to obtain a vector encoding of the entire tree. Usage could look like this:

encoder = TreeEncoder({"img": ResNet(...), "vector": DenseNN(...)})
output = encoder({"img": jnp.zeros((480, 640, 3)), "vector": jnp.zeros(5)})

The beauty of JAX is that a pytree can be an arbitrary structure, not just a dict. I make heavy use of this fact and sometimes just define my own dataclasses.
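As a minimal sketch of what such a custom pytree might look like: the Observation class and its fields below are hypothetical and chosen only for illustration, and registering it via jax.tree_util.register_pytree_node_class is one possible approach, not necessarily the one used in the project described here.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical container, not taken from the original post.
@jax.tree_util.register_pytree_node_class
@dataclass
class Observation:
    img: jax.Array
    vector: jax.Array

    def tree_flatten(self):
        # The array fields are the children; no static auxiliary data is needed.
        return (self.img, self.vector), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the container from its children in the same order.
        return cls(*children)


obs = Observation(img=jnp.zeros((4, 4, 3)), vector=jnp.zeros(5))

# tree_map visits every leaf and rebuilds an Observation holding the results.
shapes = jax.tree_util.tree_map(lambda x: x.shape, obs)
leaves = jax.tree_util.tree_leaves(obs)
```

Once registered, the dataclass can be passed anywhere a dict of arrays could be, which is why a dict-specific fix on the module side is not enough.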

But here comes my problem: if I use a dictionary, Flax will sometimes convert it into a FrozenDict, and then the tree_map call fails with

Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: {'img': ..., 'vector': ...}

It took me a while to understand when these conversions happen, but I am fairly certain now that flax behaves as follows:

  1. The inputs to __call__ are never converted (data is always a dict)
  2. The leaf_encoders are converted iff the TreeEncoder instance is created inside a @nn.compact call

How do I deal with this? I cannot simply call self.leaf_encoders.unfreeze() because it might not be a FrozenDict (or even a dictionary). Is there some way I can disable the FrozenDict conversion in general? Or is it possible to make FrozenDicts and dicts compatible as arguments to jax.tree_map?

Thanks a lot in advance!

Best,
Tim

TimSchneider42 changed the title from "Best practice of dealing with sporadic FrozenDicts conversion?" to "Best practice of dealing with sporadic FrozenDict conversions?" on Jun 13, 2024
@IvyZX
Collaborator

IvyZX commented Jun 21, 2024

Hmm, are you using the latest Flax? In the latest Flax, flax.config.flax_return_frozendict should be False, so a normal dict is returned by default.

Otherwise, you can do an explicit check like hasattr(x, "unfreeze") and call x.unfreeze() only when it is true. But please let me know if you are using the latest Flax and still run into this problem.
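A minimal sketch of that explicit check, written as a small helper. The name maybe_unfreeze is hypothetical, and FrozenStandIn is only a stand-in so the sketch runs without Flax installed; it assumes the real FrozenDict exposes an unfreeze() method that returns a plain dict.

```python
from typing import Any


def maybe_unfreeze(x: Any) -> Any:
    """Return x.unfreeze() if x offers that method, otherwise return x unchanged."""
    if hasattr(x, "unfreeze"):
        return x.unfreeze()
    return x


class FrozenStandIn:
    """Hypothetical stand-in for a frozen mapping, used only for this demo."""

    def __init__(self, items: dict):
        self._items = dict(items)

    def unfreeze(self) -> dict:
        # Hand back a mutable plain-dict copy of the stored items.
        return dict(self._items)


plain = {"img": "resnet", "vector": "dense"}
frozen = FrozenStandIn(plain)

normalized_frozen = maybe_unfreeze(frozen)     # converted to a plain dict
normalized_plain = maybe_unfreeze(plain)       # plain dicts are returned as-is
normalized_other = maybe_unfreeze(("a", "b"))  # other pytrees pass through untouched
```

In a module like TreeEncoder, the helper would presumably be applied to self.leaf_encoders right before the tree map, so that it has the same container type as the dict passed in as data.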

cc @chiamp

@TimSchneider42
Author

Hi,

I have to double-check which Flax version I am using, but I installed it from PyPI around 2 months ago, so it should be fairly recent.

To be clear, my problem is not that modules return FrozenDicts, but rather that fields of modules sometimes get converted to FrozenDicts. I think I will go with your suggestion for now, but an option to turn that behavior off fully would be nice for the future.

Best,
Tim
