flax nn.tabulate Incorrectly Reports FLOPs and VJP FLOPs #4023
Comments
Hmm, unfortunately I cannot repro this (Flax 0.8.5). My printout yields this:
[tabulate output omitted]
The code does work with the pinned package configurations on Colab and Kaggle, but fails when installed with the same package versions on a local machine. The data above comes from a fresh pip install of flax, jax, and jaxlib (CUDA) in a mamba environment (though that shouldn't affect it). For reference, the Colab and Kaggle runtimes use system-level CUDA packages, while the pip-installed versions ship their own CUDA wheels. Here's the minimal dependency list anyway:
[dependency list omitted]
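One quick way to confirm which jax/jaxlib build and backend the local environment actually loaded (a generic sanity check, not part of the original report):

```python
# Generic environment check: prints installed versions and whether the
# CUDA backend actually initialized on this machine.
import jax
import jaxlib

print(jax.__version__, jaxlib.__version__)
print(jax.default_backend())  # "gpu"/"cuda" if the CUDA build loaded, else "cpu"
print(jax.devices())
```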
System information
OS: Ubuntu 22.04.4 LTS x86_64
Python: 3.12.4
GPU: NVIDIA GeForce RTX 3080 Ti
CUDA: 12.2
Problem you have encountered:
When running a script that tabulates the model summary, including FLOPs and VJP FLOPs, using Flax's `nn.tabulate` function, the output incorrectly shows both FLOPs and VJP FLOPs as 0. This is unexpected: the model does perform computations that should produce a non-zero FLOPs count, and the VJP FLOPs in particular should be a non-zero integer given the model's structure and operations.

What you expected to happen:
The output should correctly calculate and display the FLOPs and VJP FLOPs for each layer in the model.
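For reference, a minimal sketch of the kind of script involved; the model, input shape, and RNG seed below are stand-ins, not the reporter's actual code:

```python
# Minimal sketch (assumed stand-in model, not the reporter's script).
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)  # dense layers should contribute non-zero FLOPs
        x = nn.relu(x)
        return nn.Dense(10)(x)

x = jnp.ones((1, 32))
# compute_flops / compute_vjp_flops are the tabulate options at issue;
# on affected installs both columns reportedly print 0.
print(nn.tabulate(MLP(), jax.random.PRNGKey(0),
                  compute_flops=True,
                  compute_vjp_flops=True)(x))
```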
Logs, error messages, etc: