Increasing eval batch size in Trainer API causes size mismatch during evaluation

Hello all,
I am trying to distill wav2vec2-base-960h on the TIMIT dataset by having the student model's outputs match the teacher model's via a kl_div loss. For this I wrote a CustomTrainer that overrides compute_loss:

import torch
from torch.nn.functional import kl_div, log_softmax, softmax
from transformers import Trainer

class CustomTrainer(Trainer):
    def __init__(self, model, args=None, data_collator=None, train_dataset=None, eval_dataset=None,
                 tokenizer=None, model_init=None, compute_metrics=None, callbacks=None, optimizers=(None, None)):
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer,
                         model_init, compute_metrics, callbacks, optimizers)

    def compute_loss(self, model_obj, inputs, return_outputs=False):
        labels = inputs["labels"]  # currently unused, the loss only matches logits

        outputs_student = model_obj(**inputs)
        with torch.no_grad():
            outputs_teacher = teacher_model(**inputs)

        # outputs_student = model_obj.forward(inputs['input_values'].to(DEVICE), is_student=True)
        # outputs_teacher = model_obj.forward(inputs['input_values'].to(DEVICE), is_student=False)

        student_logits = outputs_student.get("logits")
        teacher_logits = outputs_teacher.get("logits")

        # print(student_logits.shape, teacher_logits.shape)

        # match the student's output distribution to the teacher's;
        # kl_div expects log-probabilities as input and probabilities as target
        kl_div_loss = kl_div(
            log_softmax(student_logits, dim=-1),
            softmax(teacher_logits, dim=-1),
            reduction='batchmean',
        )

        return (kl_div_loss, student_logits) if return_outputs else kl_div_loss
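
For context, the teacher model referenced above is created once outside the trainer, roughly like this (a sketch; the checkpoint name is the one mentioned above and my actual loading code may differ slightly):

import torch
from transformers import Wav2Vec2ForCTC

# Teacher: the pretrained checkpoint, frozen and kept in eval mode
# (sketch -- exact details assumed for illustration)
teacher_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad = False
teacher_model.to("cuda" if torch.cuda.is_available() else "cpu")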

Now the issue: whenever I set per_device_eval_batch_size to more than 2, I get a size-mismatch error during evaluation, which does not occur at smaller eval batch sizes. These are my training arguments and trainer setup:

training_args = TrainingArguments(
  disable_tqdm=False,
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=32,
  per_device_eval_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  gradient_checkpointing=True, 
  save_steps=200,
  eval_steps=200,
  logging_steps=200,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
  # push_to_hub=True
)
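
In case the collator is relevant: data_collator follows the standard CTC-with-padding pattern from the wav2vec2 fine-tuning tutorial, roughly like this sketch (details of my actual implementation may differ):

from dataclasses import dataclass
from typing import Dict, List, Union
import torch
from transformers import Wav2Vec2Processor

@dataclass
class DataCollatorCTCWithPadding:
    # Pads audio inputs and label sequences independently, then masks the
    # label padding with -100 so it is ignored by the loss.
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(label_features, padding=self.padding, return_tensors="pt")

        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)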

trainer = CustomTrainer(
    model=student_model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.feature_extractor,
)

trainer.train()
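
As a sanity check, the KL term itself runs fine on dummy logits of matching shape (the shapes below are made up purely for illustration; real sizes depend on the padded audio lengths in each batch):

import torch
from torch.nn.functional import kl_div, log_softmax, softmax

# dummy (batch, time_steps, vocab_size) logits just to illustrate shapes
student_logits = torch.randn(4, 200, 32)
teacher_logits = torch.randn(4, 200, 32)

loss = kl_div(
    log_softmax(student_logits, dim=-1),  # input as log-probabilities
    softmax(teacher_logits, dim=-1),      # target as probabilities
    reduction="batchmean",
)
print(loss.item())  # scalar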

Any help as to why this occurs? Thanks in advance…