I am using flax.linen.remat on a module that has a train flag (used to check whether the model is training). I'm marking that flag as static with static_argnums, but I still get a ConcretizationTypeError during model.init.
Reproducer:
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train):
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


# Works fine with this line commented out
MLP = nn.remat(MLP, static_argnums=(1,))

model = MLP()
rng_key = jax.random.PRNGKey(42)
input_example = jnp.ones((4, 16))  # placeholder; the real input is not shown in this report
variables = model.init(rng_key, input_example, True)
Traceback:
$ python foo.py
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/dion/codes/supersede/foo.py", line 43, in <module>
variables = model.init(rng_key, input_example, True)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /Users/dion/.virtualenvs/science/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:393 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[2].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html
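For comparison with the hint in the error message, here is a minimal pure-JAX sketch of `jax.checkpoint` with `static_argnums`, where a static flag drives a Python `if`. The function `layer` is hypothetical and not part of the reproducer. With a plain function there is no `self`, so `train` is argument index 1:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Argument 1 (`train`) is static, so the Python `if` sees a concrete bool
# instead of a tracer and no boolean conversion of a traced value happens.
@partial(jax.checkpoint, static_argnums=(1,))
def layer(x, train):
    if train:
        return jnp.sin(x) * 2.0
    return jnp.sin(x)


# Differentiate through the checkpointed function for both flag values.
grad_train = jax.grad(lambda x: layer(x, True))(1.0)  # 2 * cos(1)
grad_eval = jax.grad(lambda x: layer(x, False))(1.0)  # cos(1)
```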
Tested with flax==0.8.4.