lstm error #4032

Open
layssi opened this issue Jun 26, 2024 · 1 comment

@layssi
layssi commented Jun 26, 2024

```python
import flax.linen as nn
import jax, jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (2, 3))  # batch of 2 inputs, 3 features each
layer = nn.LSTMCell(features=4)
carry = layer.initialize_carry(jax.random.key(1), x.shape)  # initial (c, h) state
variables = layer.init(jax.random.key(2), carry, x)
new_carry, out = layer.apply(variables, carry, x)
```

Running this code, which comes from the documentation, raises the following error:

flax.errors.AssignSubModuleError: Submodule LSTMCell must be defined in setup() or in a method wrapped in @compact (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)

@rajasekharporeddy
Contributor

Hi @layssi

I tested the code above on Colab (CPU, GPU, and TPU v2) with JAX 0.4.26 and Flax 0.8.4, and on a MacBook CPU with JAX 0.4.30 and Flax 0.8.5. I could not reproduce the error; the code runs without issues in all of these environments.

Please see the gist for reference.

Could you please verify whether the issue persists with the latest versions of JAX and Flax?

Thank you.
