Calling python run_t5_mlm_flax.py when running on multiple GPUs

Hello,

I am trying to run run_t5_mlm_flax.py with 8 GPUs, but I get this error:
NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: unhandled cuda error
Let me know if you have any suggestions! :slight_smile:

I’ve run this script on a GPU cluster with 8 GPUs on a single node and it worked as expected. Are you running a distributed or a local setup? Setting export NCCL_DEBUG=INFO and inspecting the output could give some insight here. For example, in my setup export NCCL_SOCKET_IFNAME=hsn0,hsn1,hsn2,hsn3 solved some NCCL issues.
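If it helps, here is a minimal sketch of one way to set these NCCL variables from Python before JAX initializes (the interface names hsn0–hsn3 are specific to my cluster, so treat them as placeholders and replace them with whatever ip link shows on your nodes):

```python
import os

# NCCL reads these environment variables when it initializes, so set them
# before the first `import jax` (or export them in the shell before launch).
os.environ["NCCL_DEBUG"] = "INFO"  # print NCCL's own diagnostics to stderr
# Restrict NCCL to the network interfaces that actually work on your nodes;
# "hsn0,hsn1,hsn2,hsn3" are the names on my cluster -- replace with yours.
os.environ["NCCL_SOCKET_IFNAME"] = "hsn0,hsn1,hsn2,hsn3"

import jax

# Sanity check: with 8 GPUs on a single node this should report 8 devices.
print("local devices:", jax.local_device_count())
print("total devices:", jax.device_count())
```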

I’m struggling to get a multi-node setup going. Is it only a matter of using a different launch configuration, or does the script itself require changes, e.g. for handling data sharding? I’m just starting with JAX/Flax and have no previous experience. If somebody knows, please let me know! :slight_smile:
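For context, this is the kind of multi-node bootstrap I’ve seen in JAX examples; I haven’t verified whether run_t5_mlm_flax.py already does any of this, so treat it as a rough sketch, and the coordinator address, port, and process counts below are placeholders:

```python
import jax

# Hypothetical multi-node setup: every node runs the same script and passes
# its own process_id; the coordinator address must be reachable from all nodes.
jax.distributed.initialize(
    coordinator_address="node0:12345",  # host:port of the coordinating node
    num_processes=2,                    # total number of participating processes
    process_id=0,                       # 0 on the first node, 1 on the second, ...
)

# Each process only sees the GPUs on its own node.
print("process", jax.process_index(), "of", jax.process_count())
print("local devices:", jax.local_device_count(), "global:", jax.device_count())

# The input pipeline then has to shard data per process, e.g. each process
# reads only its own slice of the dataset based on jax.process_index().
my_shard = jax.process_index()
```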