torch.cuda.OutOfMemoryError when evaluating during training

I am trying to train an entailment model for zero-shot learning. During training the GPU uses only 5-6 GB of memory, but when evaluation runs (every 50 steps) it goes out of GPU memory.
My train set has 6 million samples and my test set has 2 million samples.
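For scale, here is my rough back-of-envelope estimate (my own numbers, not measured) of what the tokenized test set alone would occupy if it were fully materialized on the GPU as int64 tensors:

# rough estimate: input_ids + attention_mask stored as int64 at max_length 256
n_test = 2_000_000
max_length = 256
bytes_per_token = 8                              # int64
per_example = 2 * max_length * bytes_per_token   # input_ids + attention_mask
print(n_test * per_example / 1024**3)            # ~7.6 GB for the test set alone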
Here is my code:

import sys
sys.path.append('../../')
import evaluate
import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

device = torch.device("cuda")
print(device)
model_path = "../../pretrain_model/phobert-base"
n_labels = 2
max_length = 256
print('load tokenize model')
tokenizer = AutoTokenizer.from_pretrained(model_path)
print('done load')

def encode(batch):
    # Tokenize sentence pairs, truncating and padding everything to max_length.
    return tokenizer(
        batch['sentence1'],
        batch['sentence2'],
        return_token_type_ids=False,
        max_length=max_length,
        truncation='longest_first',
        padding='max_length',
        return_tensors='pt',
    )


accuracy = evaluate.load("../../evaluate/metrics/accuracy")
f1_score = evaluate.load("../../evaluate/metrics/f1")
precision = evaluate.load("../../evaluate/metrics/precision")
recall = evaluate.load("../../evaluate/metrics/recall")
metric_name = "f1_score"

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    labels = p.label_ids
    return {
        'accuracy': list(accuracy.compute(predictions=preds, references=labels).values())[0],
        'f1_score': list(f1_score.compute(predictions=preds, references=labels).values())[0],
        'precision': list(precision.compute(predictions=preds, references=labels).values())[0],
        'recall': list(recall.compute(predictions=preds, references=labels).values())[0],
    }
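# NOTE (my understanding, please correct me if wrong): the Trainer gathers the
# predictions for the whole eval set before calling compute_metrics once, so
# preds above is a single (num_test_samples, n_labels) array.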


id2label = {0: 'contradiction', 1: 'entailment'}
label2id = {'contradiction': 0, 'entailment': 1}

model = AutoModelForSequenceClassification.from_pretrained(
    model_path, num_labels=n_labels, id2label=id2label, label2id=label2id
)
model = model.to(device)

# Freeze everything except the classification head (phobert is RoBERTa-based,
# so the head parameters are named "classifier.*").
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, max_length=max_length, padding='max_length', return_tensors='pt'
)

dataset = load_dataset("csv", data_files={"train": "../dataset/train_zeroshot_v2.csv", "test": "../dataset/test_zeroshot_v2.csv"})
print('process data')

trainset_processed = dataset['train'].map(
    encode, batched=True,
    cache_file_name='/home/jovyan/.cache/huggingface/datasets/absa_zeroshot_train_mapped',
)
testset_processed = dataset['test'].map(
    encode, batched=True,
    cache_file_name='/home/jovyan/.cache/huggingface/datasets/absa_zeroshot_test_mapped',
)

trainset_processed = trainset_processed.remove_columns(["sentence1", "sentence2"])
testset_processed = testset_processed.remove_columns(["sentence1", "sentence2"])

trainset_processed = trainset_processed.rename_column("label", "labels")
testset_processed = testset_processed.rename_column("label", "labels")


# Return torch tensors created directly on the GPU whenever rows are accessed.
trainset_processed = trainset_processed.with_format("torch", device=device)
testset_processed = testset_processed.with_format("torch", device=device)

training_args = TrainingArguments(
    output_dir="ZeroShot",
    logging_dir='logger',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    num_train_epochs=100,
    weight_decay=0.01,
    logging_strategy="steps",
    evaluation_strategy="steps",  # eval_steps defaults to logging_steps, i.e. every 50 steps
    save_strategy="steps",
    logging_steps=50,
    save_total_limit=10,
    metric_for_best_model=metric_name,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=trainset_processed,
    eval_dataset=testset_processed,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
)

trainer.train()
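
What I was considering trying, though I am not sure it is the right direction: keep the formatted datasets on the CPU so the Trainer moves only each batch to the GPU, and set eval_accumulation_steps so the outputs accumulated during evaluation are offloaded to the CPU periodically. A sketch of the change (untested):

# keep tensors on the CPU; the Trainer transfers each batch to the GPU itself
trainset_processed = trainset_processed.with_format("torch")
testset_processed = testset_processed.with_format("torch")

training_args = TrainingArguments(
    output_dir="ZeroShot",
    per_device_eval_batch_size=8,
    eval_accumulation_steps=100,  # move accumulated eval outputs to CPU every 100 steps
    # ... all other arguments unchanged ...
)

Would something along these lines fix the evaluation OOM, or is something else eating the memory? Thanks for any help.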