Segmentation fault with gradient_checkpointing on multi-GPU

I am trying to train a wav2vec2 model on my own dataset by following this template.

I have two issues:

  1. The model does not seem to be learning much. Changing the learning rate does change the results, but none of the values I have tried gives good enough results.

  2. If I set gradient_checkpointing=True, training crashes with a segmentation fault (core dumped) whenever CUDA_VISIBLE_DEVICES lists more than one GPU (single node). With only one GPU visible it runs fine, regardless of which GPU it is. Without gradient_checkpointing, training uses multiple GPUs without problems. Is this a known issue or intended behavior? Are there any extra options that need to be set?
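Issue 2 is easiest to report (and to narrow down) as a small matrix of runs. The Python sketch below runs the same training script under each CUDA_VISIBLE_DEVICES setting with gradient checkpointing on and off, and reports whether each process exited normally or was killed by a signal (on POSIX a segfault shows up as SIGSEGV). It assumes a hypothetical train.py that accepts a hypothetical --gradient_checkpointing flag; substitute the real script and option from the template. It launches plain python, so if training is started through a launcher such as torchrun, the launcher may report its own exit code instead of the child's signal.

```python
import itertools
import os
import signal
import subprocess
import sys

# Hypothetical entry point: replace with the actual training script.
TRAIN_CMD = [sys.executable, "train.py"]
# Hypothetical flag: replace with however the script enables gradient_checkpointing.
CKPT_FLAG = "--gradient_checkpointing"


def build_matrix(gpu_sets=("0", "1", "0,1"), checkpointing=(False, True)):
    """Return every (CUDA_VISIBLE_DEVICES, gradient_checkpointing) combination."""
    return list(itertools.product(gpu_sets, checkpointing))


def describe_returncode(rc):
    """Translate a subprocess return code into a readable outcome."""
    if rc == 0:
        return "ok"
    if rc < 0:
        # On POSIX, a negative return code means the child was killed by signal -rc.
        try:
            return f"killed by {signal.Signals(-rc).name}"
        except ValueError:
            return f"killed by signal {-rc}"
    return f"exit code {rc}"


def run_case(devices, use_ckpt, cmd=TRAIN_CMD):
    """Run one configuration with only the given GPUs visible."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=devices)
    args = list(cmd) + ([CKPT_FLAG] if use_ckpt else [])
    result = subprocess.run(args, env=env)
    return describe_returncode(result.returncode)


def main():
    """Print one line per configuration in the matrix."""
    for devices, use_ckpt in build_matrix():
        outcome = run_case(devices, use_ckpt)
        print(f"CUDA_VISIBLE_DEVICES={devices:<4} "
              f"gradient_checkpointing={use_ckpt!s:<5} -> {outcome}")
```

Calling main() from the directory that contains the script prints one line per combination, so the exact set of failing configurations can be pasted into the report.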

I am running CUDA 12.1 with the latest driver and the nightly development build of PyTorch on two RTX 4090 GPUs.