GPU running out of memory when using compute_metrics

I have been trying to fine-tune the ViT-GPT2 image-captioning model (nlpconnect/vit-gpt2-image-captioning on Hugging Face) on Paperspace.
Here is the hardware configuration of the machine I am using:
GPU: NVIDIA RTX A6000 (48 GB VRAM)
CPU: 8 cores
RAM: 45 GB

Here is my code for loading the model and its preprocessors:

from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor

model_name = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(model_name).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
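
Just as a sanity check, here is roughly what the two preprocessors return (the 224x224 input size is my understanding of the ViT encoder behind this checkpoint, not something I verified from its config):

from PIL import Image

# Quick check of the preprocessor outputs; the 224x224 size is my assumption
# based on the ViT-base encoder this checkpoint is built on.
sample_image = Image.new("RGB", (640, 480))
pixel_values = feature_extractor(images=sample_image, return_tensors="pt").pixel_values
print(pixel_values.shape)   # should be torch.Size([1, 3, 224, 224])

sample_ids = tokenizer("a sample caption", return_tensors="pt").input_ids
print(sample_ids.shape)     # (1, number_of_tokens)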

# data preprocessing

from torch.utils.data import Dataset
import torch
import torchvision.transforms as T
from PIL import Image
class CustomDataset(Dataset):
    def __init__(self, images, captions):
        self.images = images
        self.captions = captions
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        caption = self.captions[idx]
        # (I previously normalized the image manually with torchvision transforms,
        # but now rely on the model's own image processor instead.)
        # (1, 3, H, W) -> (3, H, W) so the default collator can stack a batch
        image_tensor = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze()
        # pads/truncates to the tokenizer's model_max_length, since no explicit max_length is given
        encoded_text = tokenizer.encode_plus(caption, padding='max_length', truncation=True, return_tensors="pt").input_ids
        inputs = {
            'pixel_values': image_tensor,
            'labels': encoded_text,
        }

        return inputs
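
For completeness, the train_dataset and test_dataset passed to the Trainer below are instances of this class; my real image/caption loading is not shown in this post, so the lists here are just placeholders:

from PIL import Image

# Placeholder data only: my actual image/caption loading happens elsewhere,
# so these dummy lists just illustrate how the two datasets are constructed.
placeholder_images = [Image.new("RGB", (256, 256), color="white") for _ in range(4)]
placeholder_captions = ["a plain white image"] * 4

train_dataset = CustomDataset(placeholder_images, placeholder_captions)
test_dataset = CustomDataset(placeholder_images, placeholder_captions)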

# Training

from transformers import TrainingArguments,Trainer
training_args = TrainingArguments(
    learning_rate=5e-5,
    num_train_epochs=2,
    fp16=True,
    per_device_train_batch_size=2,  # Reduce the batch size to a small value
    per_device_eval_batch_size=2,   # Set evaluation batch size
    gradient_accumulation_steps=10,  # Increase gradient accumulation steps
    save_total_limit=3,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    logging_steps=100,
    remove_unused_columns=False,
    output_dir="./image-captioning-output",
)
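
(Side note on the numbers above: with per_device_train_batch_size=2 and gradient_accumulation_steps=10, the effective batch per optimizer step works out as follows.)

# My reading of the config above: samples contributing to one optimizer step on a single GPU
effective_train_batch = 2 * 10   # per_device_train_batch_size * gradient_accumulation_steps
print(effective_train_batch)     # 20
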
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
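
Training is then launched with the same call that appears in the traceback further down:

import os

os.environ["WANDB_DISABLED"] = "true"  # disable Weights & Biases logging, as in my notebook cell
trainer.train()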

This runs fine and gives me the following output:


Step    Training Loss    Validation Loss
100     0.032100         0.031620
200     0.030400         0.031046
TrainOutput(global_step=234, training_loss=0.030886615698154155, metrics={'train_runtime': 422.8986, 'train_samples_per_second': 11.1, 'train_steps_per_second': 0.553, 'total_flos': 8.443905465510789e+17, 'train_loss': 0.030886615698154155, 'epoch': 1.99})

But the moment I add compute_metrics (I am using word error rate as the metric), evaluation fails. Here is the code:

# code for compute metrics using word error rate
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # greedy decoding: take the highest-scoring token at each position
    predicted = logits.argmax(-1)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = tokenizer.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)


It throws the following error:

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In [41], line 3
      1 os.environ["WANDB_DISABLED"] = "true"
----> 3 trainer.train()

File /usr/local/lib/python3.9/dist-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1534     self.model_wrapped = self.model
   1536 inner_training_loop = find_executable_batch_size(
   1537     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1538 )
-> 1539 return inner_training_loop(
   1540     args=args,
   1541     resume_from_checkpoint=resume_from_checkpoint,
   1542     trial=trial,
   1543     ignore_keys_for_eval=ignore_keys_for_eval,
   1544 )

File /usr/local/lib/python3.9/dist-packages/transformers/trainer.py:1901, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1898     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   1899     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 1901     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1902 else:
   1903     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File /usr/local/lib/python3.9/dist-packages/transformers/trainer.py:2226, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2224         metrics.update(dataset_metrics)
   2225 else:
-> 2226     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2227 self._report_to_hp_search(trial, self.state.global_step, metrics)
   2229 # Run delayed LR scheduler now that metrics are populated

File /usr/local/lib/python3.9/dist-packages/transformers/trainer.py:2934, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   2931 start_time = time.time()
   2933 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 2934 output = eval_loop(
   2935     eval_dataloader,
   2936     description="Evaluation",
   2937     # No point gathering the predictions if there are no metrics, otherwise we defer to
   2938     # self.args.prediction_loss_only
   2939     prediction_loss_only=True if self.compute_metrics is None else None,
   2940     ignore_keys=ignore_keys,
   2941     metric_key_prefix=metric_key_prefix,
   2942 )
   2944 total_batch_size = self.args.eval_batch_size * self.args.world_size
   2945 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File /usr/local/lib/python3.9/dist-packages/transformers/trainer.py:3148, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3146         logits = self.preprocess_logits_for_metrics(logits, labels)
   3147     logits = self.accelerator.gather_for_metrics((logits))
-> 3148     preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
   3150 if labels is not None:
   3151     labels = self.accelerator.gather_for_metrics((labels))

File /usr/local/lib/python3.9/dist-packages/transformers/trainer_pt_utils.py:114, in nested_concat(tensors, new_tensors, padding_index)
    110 assert type(tensors) == type(
    111     new_tensors
    112 ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    113 if isinstance(tensors, (list, tuple)):
--> 114     return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
    115 elif isinstance(tensors, torch.Tensor):
    116     return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)

File /usr/local/lib/python3.9/dist-packages/transformers/trainer_pt_utils.py:114, in <genexpr>(.0)
    110 assert type(tensors) == type(
    111     new_tensors
    112 ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    113 if isinstance(tensors, (list, tuple)):
--> 114     return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
    115 elif isinstance(tensors, torch.Tensor):
    116     return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)

File /usr/local/lib/python3.9/dist-packages/transformers/trainer_pt_utils.py:116, in nested_concat(tensors, new_tensors, padding_index)
    114     return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
    115 elif isinstance(tensors, torch.Tensor):
--> 116     return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    117 elif isinstance(tensors, Mapping):
    118     return type(tensors)(
    119         {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
    120     )

File /usr/local/lib/python3.9/dist-packages/transformers/trainer_pt_utils.py:75, in torch_pad_and_concatenate(tensor1, tensor2, padding_index)
     72 tensor2 = atleast_1d(tensor2)
     74 if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
---> 75     return torch.cat((tensor1, tensor2), dim=0)
     77 # Let's figure out the new shape
     78 new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]

OutOfMemoryError: CUDA out of memory. Tried to allocate 15.34 GiB (GPU 0; 47.54 GiB total capacity; 23.72 GiB already allocated; 14.79 GiB free; 31.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
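
The error message itself suggests setting max_split_size_mb via PYTORCH_CUDA_ALLOC_CONF; as far as I understand, that would look something like the snippet below (set before the first CUDA allocation), though I have not confirmed it helps here and the 512 MB value is just a guess:

import os

# Allocator hint taken from the error message above; the 512 MB split size is an
# arbitrary guess on my part, and this must be set before CUDA is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"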

Is there any other way I can compute the metrics during evaluation while training with the Hugging Face Trainer?