Running out of GPU memory on a cluster with the smallest ESM2 config

I’m trying to train a model using esm2_t6_8M_UR50D, the smallest ESM2 model (6 layers, 8M parameters), with TrainingArguments set to per_device_train_batch_size=1, gradient_accumulation_steps=1, and fp16=True. I’m training on a cluster with four 32GB GPUs, and all four are running at close to full memory utilization. I also set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 beforehand. Even with this configuration, it runs out of memory at iteration 600 (out of 40k).
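Roughly, the relevant part of the setup looks like the sketch below (the output directory and the masked-LM head are placeholders, not my exact script; dataset loading and the Trainer call are omitted):

```python
import os

# Set before torch initializes CUDA so the allocator picks up the setting
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

from transformers import AutoModelForMaskedLM, AutoTokenizer, TrainingArguments

checkpoint = "facebook/esm2_t6_8M_UR50D"  # smallest ESM2: 6 layers, 8M params
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForMaskedLM.from_pretrained(checkpoint)  # task head is an assumption

training_args = TrainingArguments(
    output_dir="esm2_t6_8M_finetune",   # placeholder path
    per_device_train_batch_size=1,      # batch size 1 per GPU
    gradient_accumulation_steps=1,
    fp16=True,                          # mixed-precision training
)
```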

What else can I do to make it work?

@davguev What service are you using? Are you on a cloud service, or do you have your own cluster?


Our lab has its own cluster with four GPUs, and no one else was using them when I tried to train the model.