CUDA error: device-side assert triggered after a certain number of steps

Hi, I am trying to train a zero-shot topic classification model on the XNLI/vi dataset using phobert-base-v2 on Google Colab Pro+.

I get the following error whenever training reaches a certain number of steps, and that number changes with the batch_size:

RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

With batch_size = 16 it gets to around 200 steps; with batch_size = 8 it gets to around 600 steps.

I have read through several posts which suggest that the problem might come from the label indexing. I have checked the labels column in the dataset, and its values range from 0 to 2. I also tried switching to the CPU, but it is very slow, and considering that the bug only appears after training has reached a certain number of steps, it might not reveal the real issue.

This is my training code:

from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback, IntervalStrategy
import numpy as np
import evaluate

# Load every metric once at module level so compute_metrics can reuse the
# same objects on each evaluation pass.
accuracy_metric = evaluate.load("accuracy")

# The three class-wise metrics are loaded with the same keyword argument.
# NOTE(review): average="macro" is also passed on every compute() call in
# compute_metrics, so it may be redundant here - confirm against the
# `evaluate.load` documentation before removing it.
f1_metric, precision_metric, recall_metric = (
    evaluate.load(metric_name, average="macro")
    for metric_name in ("f1", "precision", "recall")
)

def compute_metrics(eval_preds):
    """Score one evaluation pass with accuracy, precision, recall and F1.

    Args:
        eval_preds: A pair that unpacks into ``(logits, labels)``.

    Returns:
        A dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. Precision, recall and F1 are computed with
        ``average="macro"``.
    """
    raw_scores, gold_labels = eval_preds

    # The predicted class is the index of the largest score on the last axis.
    predicted = np.argmax(raw_scores, axis=-1)
    shared_inputs = {"predictions": predicted, "references": gold_labels}

    scores = {"accuracy": accuracy_metric.compute(**shared_inputs)["accuracy"]}

    # Each of these metrics returns its value under a key equal to its name.
    macro_metrics = (
        ("precision", precision_metric),
        ("recall", recall_metric),
        ("f1", f1_metric),
    )
    for key, metric in macro_metrics:
        scores[key] = metric.compute(average="macro", **shared_inputs)[key]

    return scores

# Hyper-parameters and scheduling for the run, grouped by purpose.
training_args = TrainingArguments(
    # Where checkpoints and outputs are written; nothing is pushed to the Hub.
    output_dir="./zero_shot_topic_classification",
    push_to_hub=False,
    # Optimisation settings.
    optim="adamw_torch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=50,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Step-based cadence: log and evaluate every 100 steps, save every 200.
    logging_steps=100,
    evaluation_strategy=IntervalStrategy.STEPS,
    eval_steps=100,
    save_strategy=IntervalStrategy.STEPS,
    save_steps=200,
    # Reload the best checkpoint when training ends, ranked by the "f1" metric.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

# Wire the model, data and evaluation hooks into a Trainer. The names
# `model`, `tokenizer`, `data_collator`, `train_dataset` and
# `validation_dataset` are not defined in this snippet.
#
# NOTE(review): a CUDA device-side assert that only fires after some steps is
# often caused by one bad batch - e.g. an input id or position index outside
# the embedding table (sequence longer than the model supports), or a label
# outside the classifier's range. The tokenization code is not shown here, so
# confirm that inputs are truncated to the model's maximum length - TODO verify.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    compute_metrics=compute_metrics,
    # Early stopping with a patience of 10.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

# Start training.
trainer.train()