static_argnums argument to flax.linen.remat not working as expected #3946

Open
dionhaefner opened this issue May 28, 2024 · 2 comments

@dionhaefner

I am using flax.linen.remat on a module that has a train flag (used to check if the model is training). I'm using static_argnums on that flag, but am still getting a ConcretizationTypeError on model init.

Reproducer:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train):
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

# Works fine with this line commented out
MLP = nn.remat(MLP, static_argnums=(1,))

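model = MLP()
rng_key = jax.random.PRNGKey(42)
# Placeholder input (assumed shape); the original post did not define input_example
input_example = jnp.ones((4, 8))
variables = model.init(rng_key, input_example, True)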

Traceback:

$ python foo.py
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/dion/codes/supersede/foo.py", line 43, in <module>
    variables = model.init(rng_key, input_example, True)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /Users/dion/.virtualenvs/science/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:393 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[2].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

Tested with flax==0.8.4.
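
For reference, the hint in the error message points at the plain `jax.checkpoint` API. The following is a minimal sketch of how `static_argnums` behaves there, not the Flax code path from the reproducer. The function `layer` and its shapes are made up for illustration; the only assumption is that the flag is used in a Python-level `if`, which is what requires a concrete value.

```python
import jax
import jax.numpy as jnp


def layer(w, x, train):
    # Hypothetical toy function: the Python `if` needs `train` to be a real bool.
    h = jnp.tanh(x @ w)
    if train:
        h = h * 0.5
    return h.sum()


w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

# Positional arguments are counted from 0: w -> 0, x -> 1, train -> 2.
layer_remat = jax.checkpoint(layer, static_argnums=(2,))
grads = jax.grad(layer_remat)(w, x, True)

# Without static_argnums, `train` is traced and the `if` fails.
try:
    jax.checkpoint(layer)(w, x, True)
    traced_flag_failed = False
except jax.errors.ConcretizationTypeError:
    traced_flag_failed = True
```

Marking an argument static means the function is retraced for each distinct value, which is usually acceptable for a boolean flag.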

@cgarciae
Collaborator

train is the 3rd argument (self is index 0, x is index 1), so you have to change static_argnums like this:

MLP = nn.remat(MLP, static_argnums=(2,))

@cgarciae self-assigned this May 29, 2024
@cgarciae added the "Priority: P2 - no schedule" label (best effort response and resolution; no plan to work on this at the moment) May 29, 2024
@dionhaefner
Author

I see. I guess I got confused because sometimes our models are used like this:

model.apply(variables, inputs, train=False)

which triggers this error:

ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (3,)

So I assumed it wasn't counting the self argument. Any chance we could support something akin to static_argnames from jax.jit to support kwargs?
