Segmentation fault with gradient_checkpointing on multi-GPU

I am trying to train a wav2vec2 model on my own dataset by following this template.

I have two issues:

  1. The model does not seem to be learning much. I have tried different learning rates and see some differences, but the results are still not good enough.

  2. If I set gradient_checkpointing=True, training segfaults (core dumped) whenever CUDA_VISIBLE_DEVICES exposes more than one GPU (single node). With just one GPU it is fine, no matter which one. And if gradient_checkpointing is not set, training can take advantage of multiple GPUs. Is this a known issue/feature? Are there any extra options that need to be set? (A sketch of my setup is below.)
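For context, this is roughly how I'm enabling it through the Trainer API. It's only a sketch: the `gradient_checkpointing_kwargs`/`use_reentrant` option is something I've seen suggested for multi-GPU/DDP runs (it needs a fairly recent transformers), and I have not confirmed it avoids the segfault.

```python
from transformers import TrainingArguments

# Sketch of my setup; output_dir and batch size are placeholders.
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned",
    per_device_train_batch_size=8,
    gradient_checkpointing=True,
    # Assumption: the non-reentrant checkpointing variant is sometimes
    # recommended for DDP (requires transformers >= 4.35); not verified
    # as a fix for this segfault.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # DDP can also trip over unused parameters when checkpointing is on;
    # this flag is another option I've seen suggested.
    ddp_find_unused_parameters=False,
)
```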

I am running CUDA 12.1 with the latest driver and a nightly development build of PyTorch, on two RTX 4090s.
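In case the exact builds matter, here is a quick way to dump the relevant versions (standard torch/transformers attributes only):

```python
import torch
import transformers

# Environment dump for debugging version-specific issues.
print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("transformers:", transformers.__version__)
print("visible GPUs:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))
```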

Hi, did you get it solved by any chance?

Unfortunately, I'm also getting a segfault with multi-GPU. Turning off gradient checkpointing does not solve the issue either.
I'm using AMD GPUs with the following packages on Python 3.12.3:
pytorch-triton-rocm 3.0.0
torch 2.4.0+rocm6.1
transformers 4.44.2
datasets 2.21.0
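To try to rule transformers out, my next step is a bare DDP smoke test using only torch (a sketch; on ROCm builds the "nccl" backend maps to RCCL). If this also segfaults, the problem is below transformers, in the PyTorch/ROCm communication stack:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    # "nccl" is the GPU backend; on ROCm builds it is backed by RCCL.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    t = torch.ones(1, device="cuda") * (rank + 1)
    dist.all_reduce(t)  # sums the tensor across all ranks
    print(f"rank {rank}: {t.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```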