I’m trying to train a model based on `esm2_t6_8M_UR50D`, the smallest ESM2 model (6 layers, 8M parameters), using `TrainingArguments` with `per_device_train_batch_size=1`, `gradient_accumulation_steps=1`, and `fp16=True`. I’m training on a cluster with four 32 GB GPUs, and all four are fully utilized. I also set `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512` beforehand. Even with this configuration, training runs out of memory at iteration 600 (out of 40k).
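One detail worth double-checking in the setup above: `PYTORCH_CUDA_ALLOC_CONF` only takes effect if it is set before the CUDA context is initialized. A minimal sketch of setting it from inside the training script rather than the shell (the surrounding script structure here is an assumption, not my actual code):

```python
import os

# The allocator option must be set before the first CUDA call
# (e.g. before any torch.cuda use or moving the model to "cuda"),
# otherwise PyTorch silently ignores it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# ... imports of torch/transformers and the Trainer setup would follow here ...
```

Setting it in the script guarantees the ordering regardless of how the job is launched.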
What else can I do to make it work?