I am trying to fine-tune the mT5 model on a single node with 8 Nvidia A100 GPUs. I load the model with
datatype = torch.bfloat16
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, cache_dir=cache_dir, torch_dtype=datatype, device_map="auto")
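As a side note, the layer placement chosen by device_map="auto" can be inspected through the hf_device_map attribute that accelerate attaches to the model. This is only a diagnostic snippet, not part of my training script:

# Diagnostic only: print which device each submodule was assigned to by device_map="auto"
for module_name, device in model.hf_device_map.items():
    print(module_name, "->", device)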
and use Seq2SeqTrainer
num_gpus = torch.cuda.device_count()
args = Seq2SeqTrainingArguments(
    output_dir=f"/outdir",
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size // num_gpus,
    per_device_eval_batch_size=batch_size // num_gpus,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    push_to_hub=False,
)
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"].select(range(32)),
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
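Training is then started in the usual way:

trainer.train()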
My environment:
python==3.8.10
torch==2.0.1+cu117
accelerate==0.26.1
CUDA Version: 11.7
When I check compute and memory usage with nvidia-smi, I see that only one GPU is actually doing compute. The other 7 GPUs show a small amount of memory usage (possibly data batches being copied), but their utilization stays at 0%. At the same time, I can see in the training logs that 8 different losses are being calculated.
I am not able to figure out why this is happening. Could somebody please provide a hint?