Problem in usage of multi-gpu (Four RTX 4090) for FLAX #3931
The error messages you posted all look like normal INFO printouts - is there a more specific error message or stack trace, or did the program just crash after these printouts? In general, hardware issues like multi-GPU are more likely rooted in JAX, as Flax rarely touches lower-level APIs directly. I'd also recommend trying some smaller, pure-JAX code (like from this multi-device guide or other JAX website sample code) to pinpoint the error to more specific lines.
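For reference, a minimal pure-JAX smoke test of the kind suggested above (this is a sketch, not the reporter's code; it only assumes JAX is installed and uses `jax.pmap` to run a trivial computation on every visible device):

```python
# Minimal multi-device smoke test in pure JAX. If this hangs or crashes,
# the problem is below Flax (JAX, XLA, the driver, or NCCL), not the model code.
import jax
import jax.numpy as jnp

print(jax.devices())          # should list all four GPUs
n = jax.local_device_count()

# Shard a toy batch across devices and run a trivial pmapped function.
xs = jnp.arange(n * 4.0).reshape(n, 4)
ys = jax.pmap(lambda x: x * 2.0)(xs)
print(ys.shape)               # leading axis equals the device count
```

If this small test also blocks, the issue can be reported against JAX directly with a much shorter reproducer.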
Thanks for your reply!
For your question "is there a more specific error message or stack trace, or did the program just crash after these printouts?": the program just hung after this printout without further progress (I waited more than 6 hours but nothing happened). I also found that, since the RTX 4090 series, NVLink is no longer included. Could this be the reason for this error? Thanks!
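One possibly relevant detail: the RTX 4090 has no NVLink and also does not support peer-to-peer (P2P) transfers, which is a known cause of NCCL collectives hanging on multi-4090 machines. A sketch of a common workaround, using NCCL's standard environment variables (these must be set before JAX, and hence NCCL, is initialized):

```python
import os

# Configure NCCL before JAX/XLA is initialized:
# NCCL_P2P_DISABLE=1 avoids peer-to-peer copies, which the RTX 4090
# does not support; NCCL_DEBUG=INFO makes NCCL log its transport choices.
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_P2P_DISABLE"] = "1"

# import jax  # import JAX only after the environment is configured
```

With `NCCL_DEBUG=INFO` set, the NCCL log lines printed at startup show which transport (P2P, shared memory, or sockets) is actually being used.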
From your description it sounds like the program is blocked, rather than failing and exiting immediately? Another thing worth doing is adding a lot of prints to your code to bisect which line it is blocked at. Also, just FYI,
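A sketch of that bisection idea (the `checkpoint` helper and the marker messages are made up for illustration): print progress markers with `flush=True`, so the last marker that appears tells you where the program blocked even if the process never exits.

```python
import sys

def checkpoint(msg: str) -> None:
    # Flush immediately: if a later call blocks forever, buffered output
    # would otherwise never reach the terminal.
    print(msg, file=sys.stderr, flush=True)

checkpoint("1: devices visible")
checkpoint("2: state replicated")
checkpoint("3: entering first pmapped train step")
```

Sprinkling these between the setup, replication, and first training step narrows a silent hang down to a single call.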
Thanks for your reply!
The info message is as follows:
... and no further progress. Thanks!
Do you have any printout in
Thanks for your comment!
Output:
It seems the state is not replicated? Also, could you check whether there is some problem in
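One way to check replication without running the full training loop (a pure-JAX sketch; the `params` pytree here is a stand-in for the actual train state) is to replicate a pytree across local devices and confirm that every leaf gains a leading axis of size `jax.local_device_count()`:

```python
import jax
import jax.numpy as jnp

# Stand-in for the real train state.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}

# Copy the pytree onto every local device (what flax.jax_utils.replicate
# does under the hood).
replicated = jax.device_put_replicated(params, jax.local_devices())

# Every leaf should now carry a leading device axis.
shapes = jax.tree_util.tree_map(lambda x: x.shape, replicated)
print(shapes)  # leading dimension of each leaf = jax.local_device_count()
```

If the leaves of the state passed to the pmapped train step lack that leading device axis, `pmap` will fail or misbehave.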
Dear FLAX community,
System information
OS Platform and Distribution: Ubuntu 22.04.3 LTS
Flax, jax, jaxlib versions : Flax: 0.8.1 / jax: 0.4.27 / jaxlib: 0.4.27+cuda12.cudnn89
Python version: 3.10
GPU/TPU model and memory & CUDA version (if applicable):
Problem you have encountered:
As shown in the image above, my server is equipped with four RTX 4090 GPUs. I tried to run batch training across multiple GPUs, but it didn't work and produced the error message below. It seems to me that the problem comes from the NVIDIA GPUs, not from Python.
What you expected to happen:
I want to use multiple GPUs for batch training in FLAX in my server environment. How can I fix my code or rebuild the environment? (I am quite new to Linux...)
Logs, error messages, etc:
Error message is as follows:
Steps to reproduce:
I followed this benchmark code (https://colab.research.google.com/drive/1hXns2b6T8T393zSrKCSoUktye1YlSe8U?usp=sharing#scrollTo=oKcRiQ89xQkF) and fixed several issues. The code I ran on my server is as follows:
(In this paragraph, the error message appears)
It seems that lots of people are having trouble with multi-GPU setups on the RTX 4090 (e.g. the assertion
index >= -sizes[i] && index < sizes[i] && "index out of bounds"
failed in huggingface/transformers#24056). Thanks for reading!