Wav2vec2 CUDA OOM with distributed training

I’m trying to reproduce distributed training for the Wav2vec2 model.
My computation resources:
A server with 128 cores, 2 TB RAM, and 8× A100 40GB GPUs.
For the dataset, I’m trying to train a Chinese ASR acoustic model on about 580 GB of WAV files.
My reproduction code is in this repo:

My question is: with the same configuration on a single A100, it works fine.
But when I run with all 8 GPUs, CUDA OOM appears after a few minutes.
What is wrong with my code?
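
For context, here is a minimal sketch of the kind of multi-GPU setup I mean (this is not my exact repo code; the checkpoint name, optimizer, and torchrun launcher are just illustrative):

```python
# Minimal DDP sketch: same per-GPU batch size as the single-GPU run,
# launched as 8 processes with torchrun. Placeholder model/optimizer.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import Wav2Vec2ForCTC

def main():
    # torchrun sets LOCAL_RANK for each of the 8 worker processes
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Example checkpoint only; my actual config lives in the repo above
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
    model = DDP(model.cuda(local_rank), device_ids=[local_rank])

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # ... DataLoader with DistributedSampler, same per-GPU batch size ...
    # for batch in loader:
    #     out = model(input_values=batch["input_values"].cuda(local_rank),
    #                 labels=batch["labels"].cuda(local_rank))
    #     out.loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad()

if __name__ == "__main__":
    main()

# Launch: torchrun --nproc_per_node=8 train.py
```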