Trainer object high memory usage on Google Cloud Platform Workbench instance

I am trying to fine-tune the Audio Spectrogram Transformer on the ESC-50 dataset and am running into issues getting the code to work on a Google Cloud Platform Workbench instance. When I run it locally (RTX 4060 GPU, Python 3.10, torch 2.4.0+cu121), it runs without issues following this example: Google Colab

However, on the cloud instance I run into memory limits, even though the VM is generously specced (8 vCPUs, 30 GB of RAM, and a T4 with 15 GB of VRAM).

I noticed that merely creating the Trainer object uses up a whopping 11 GB of VRAM, almost maxing out the GPU. For comparison, creating the same Trainer object locally only reserves around 1.5 GB of VRAM.
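
For reference, a minimal readout along these lines (just a sketch, assuming the model sits on a plain torch.cuda device) is what I would run right after constructing the Trainer to get those numbers from PyTorch itself:

import torch

# rough VRAM readout right after the Trainer is built
# (only meaningful for a regular cuda device, not an XLA one)
print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GiB")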

In addition, it gives several warning messages that I don’t understand:
“Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES=1”

“XLA service 0x55cce98be1f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:”

“successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero”
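
My best guess is that the first two messages come from torch_xla / the XLA runtime being present on the Workbench image rather than from plain PyTorch, but I am not sure. A quick check along these lines (a sketch, not part of the notebook) should at least show whether that package is present and which device the Trainer resolves:

import importlib.util
import torch

# is torch_xla installed on this image? Its presence would explain the PJRT/XLA messages
print("torch_xla installed:", importlib.util.find_spec("torch_xla") is not None)
print("torch sees CUDA:", torch.cuda.is_available())

# after building the TrainingArguments below, the resolved device can be inspected:
# print(args.device)  # expecting cuda:0; an xla:* device would point at the XLA backend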

If I run trainer.evaluate(), it completes despite the warnings. However, if I run trainer.train(), it fails by running out of memory.

Once again, the same code runs without issues locally on my 8 GB RTX 4060.

Does anyone here have experience running Hugging Face Transformers on Google Cloud Platform who could help me understand this behaviour?

from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
import evaluate
from datetime import datetime
import numpy as np


# Generate a datestring for filenames
now = datetime.now() # current date and time
date_time = now.strftime("%Y%m%d-%H%M%S")
print(date_time)

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    model_checkpoint, 
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

model_name = model_checkpoint.split("/")[-1]

batch_size = 8
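# effective training batch size = per_device_train_batch_size * gradient_accumulation_steps = 8 * 4 = 32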

args = TrainingArguments(
    f"{date_time}-{model_name}",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    eval_steps=1,
    warmup_ratio=0.1,
    logging_steps=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

accuracy = evaluate.load("gza_data/evaluate/metrics/accuracy/accuracy.py")
recall = evaluate.load("gza_data/evaluate/metrics/recall/recall.py")
precision = evaluate.load("gza_data/evaluate/metrics/precision/precision.py")
f1 = evaluate.load("gza_data/evaluate/metrics/f1/f1.py")
roc_auc_score = evaluate.load("gza_data/evaluate/metrics/roc_auc")  # binary config; pass "multiclass" for the multiclass config
    
#AVERAGE = "macro" if config.num_labels > 2 else "binary"
AVERAGE = "macro"

def compute_metrics(eval_pred):
    logits = eval_pred.predictions
    predictions = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
    metrics.update(precision.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    #metrics.update(roc_auc_score.compute(prediction_scores=predictions, references=eval_pred.label_ids)) # add macro for multiclass
    return metrics

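# creating this Trainer is the step that reserves ~11 GB of VRAM on the cloud VM (vs ~1.5 GB locally)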
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"].with_format("torch"),
    eval_dataset=encoded_dataset["validation"].with_format("torch"),
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics
)

trainer.evaluate()
trainer.train()