I am trying to run the example script run_t5_mlm_flax.py. I have the following GPU configuration when I run nvidia-smi
1. Which versions of JAX and Flax do I need to install from the JAX GitHub repo?
2. I am also getting the following error:
2023-06-08 01:38:47.110150: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.4.1 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
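My reading of this message is that the cuDNN library loaded at runtime must have the same major version as the one the binary was compiled against, and an equal or higher minor version. Below is a minimal Python sketch of that rule as I understand it. The function name is my own and is not part of JAX or cuDNN; it only encodes the wording of the log message.

```python
def cudnn_runtime_is_compatible(loaded: str, compiled: str) -> bool:
    """Return True if the loaded cuDNN satisfies the rule stated in the log.

    Hypothetical helper for illustration only: same major version,
    and a minor version that is equal to or higher than the compiled one.
    """
    loaded_major, loaded_minor = (int(p) for p in loaded.split(".")[:2])
    compiled_major, compiled_minor = (int(p) for p in compiled.split(".")[:2])
    return loaded_major == compiled_major and loaded_minor >= compiled_minor


# Versions taken from the error above: runtime 8.4.1, compiled against 8.6.0.
print(cudnn_runtime_is_compatible("8.4.1", "8.6.0"))  # False, since minor 4 < 6
```

If this reading is right, my runtime cuDNN 8.4.1 fails the check because its minor version is lower than 8.6.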
I then downgraded JAX and Flax to lower versions (a build for cudnn 802), and after that I got the following errors:
2023-06-08 01:23:26.145096: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] *** WARNING *** You are using ptxas 11.0.194, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-06-08 01:23:26.149641: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2023-06-08 01:23:26.149664: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2023-06-08 01:23:26.152107: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:632] failed to get PTX kernel "shift_right_logical_3" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
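In case it helps with diagnosing the mismatch, this is roughly how I can report which package versions are installed in my environment. It is a minimal sketch using only the Python standard library; the function name and the list of packages are my own assumptions about what is relevant here.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Map each package name to its installed version, or None if it is missing.

    Hypothetical helper; the default package list is an assumption.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

I can post the output of this together with the nvidia-smi output if that is useful.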
Any solutions would be greatly appreciated. I am currently doing a test run on the provided sample example; after that I will use the script to pre-train on my own dataset, which I have already processed.