I am trying to run the example code. When I run nvidia-smi, I get the following GPU configuration:

1. Which JAX and Flax versions do I need to install from the JAX GitHub repo?
2. I am also getting the following errors:

2023-06-08 01:38:47.110150: E external/xla/xla/stream_executor/cuda/] Loaded runtime CuDNN library: 8.4.1 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
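The compatibility rule stated in that error message (same cuDNN major version, equal or higher minor version) can be written as a small check. This is only a sketch, assuming plain "major.minor.patch" version strings. The helper names are hypothetical, and the two version numbers come from the log above.

```python
from __future__ import annotations


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a dotted version string such as '8.4.1' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def cudnn_is_compatible(runtime: str, compiled: str) -> bool:
    """Apply the rule quoted in the XLA error message: the runtime cuDNN must
    have the same major version as the one JAX was compiled with, and an
    equal or higher minor version."""
    rt_major, rt_minor = parse_version(runtime)[:2]
    cc_major, cc_minor = parse_version(compiled)[:2]
    return rt_major == cc_major and rt_minor >= cc_minor


# Versions taken from the error message: runtime 8.4.1, compiled with 8.6.0.
print(cudnn_is_compatible("8.4.1", "8.6.0"))  # False, because 8.4 < 8.6
print(cudnn_is_compatible("8.6.0", "8.6.0"))  # True
```

By this rule, a cuDNN 8.4.1 runtime is too old for a build compiled against 8.6.0. Either the installed cuDNN needs to reach 8.6 or later within the 8.x series, or the JAX build needs to match the installed cuDNN.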

I then downgraded JAX and Flax to lower versions (the cudnn 802 builds) and got the following errors:

2023-06-08 01:23:26.145096: W external/org_tensorflow/tensorflow/stream_executor/gpu/] *** WARNING *** You are using ptxas 11.0.194, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2023-06-08 01:23:26.149641: W external/org_tensorflow/tensorflow/stream_executor/gpu/] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2023-06-08 01:23:26.149664: W external/org_tensorflow/tensorflow/stream_executor/gpu/] Used ptxas at ptxas
2023-06-08 01:23:26.152107: E external/org_tensorflow/tensorflow/stream_executor/cuda/] failed to get PTX kernel “shift_right_logical_3” from module: CUDA_ERROR_NOT_FOUND: named symbol not found
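These warnings report that the ptxas found on PATH ("Used ptxas at ptxas") is 11.0.194, which is below 11.1. Compute capability 8.6 also first became supported in CUDA 11.1. The following sketch checks which ptxas is first on PATH and compares its version against 11.1. It assumes `ptxas --version` prints a line of the form "Cuda compilation tools, release 11.0, V11.0.194", and the function names are hypothetical.

```python
from __future__ import annotations

import re
import shutil
import subprocess

# `ptxas --version` prints a line like "Cuda compilation tools, release 11.0, V11.0.194".
_VERSION_RE = re.compile(r"V(\d+)\.(\d+)\.(\d+)")

# Threshold from the warning above: ptxas before 11.1 is known to miscompile XLA code.
MIN_PTXAS = (11, 1)


def parse_ptxas_version(output: str) -> tuple[int, int, int] | None:
    """Extract (major, minor, patch) from `ptxas --version` output, or None."""
    match = _VERSION_RE.search(output)
    if match is None:
        return None
    major, minor, patch = (int(g) for g in match.groups())
    return (major, minor, patch)


def ptxas_version() -> tuple[int, int, int] | None:
    """Run the first ptxas on PATH (the one the log says XLA used) and parse it."""
    path = shutil.which("ptxas")
    if path is None:
        return None
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=False)
    return parse_ptxas_version(result.stdout)


def ptxas_is_new_enough(version: tuple[int, int, int] | None) -> bool:
    """True only if a version was found and it is at least 11.1."""
    return version is not None and version[:2] >= MIN_PTXAS


if __name__ == "__main__":
    found = ptxas_version()
    status = "OK" if ptxas_is_new_enough(found) else "too old or not found"
    print("ptxas on PATH:", found, status)
```

If this reports 11.0.x, putting a newer CUDA toolkit's `bin` directory ahead of the old one on PATH might be enough, since the warning itself notes that updating only the ptxas binary is often sufficient.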

Any solutions would be greatly appreciated. I am currently doing a test run on the provided sample example. After that, I plan to use it to pre-train on my own dataset, which I have already processed.
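For question 1, it may help to report the exact installed package versions along with the devices JAX actually detects. The sketch below uses only the standard `importlib.metadata` API plus `jax.devices()`, and the helper name is hypothetical. For the CUDA jaxlib wheels of that period, the reported jaxlib version string may include a suffix naming the CUDA/cuDNN build it targets.

```python
from __future__ import annotations

from importlib import metadata

# jaxlib is the compiled part that carries the CUDA/cuDNN build,
# so its version matters as much as the jax version.
PACKAGES = ("jax", "jaxlib", "flax")


def installed_versions(packages=PACKAGES) -> dict[str, str | None]:
    """Return {package: version}, using None for packages that are not installed."""
    versions: dict[str, str | None] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    try:
        import jax

        # Shows whether JAX picked up the GPU or fell back to CPU.
        print("devices:", jax.devices())
    except Exception as exc:  # jax missing, or the backend failed to initialize
        print("could not query JAX devices:", exc)
```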