RuntimeError: The size of tensor a (3) must match the size of tensor b (512) at non-singleton dimension 2

Hello,
I have been working with modifying class weights for fine-tuning BioBERT for imbalanced dataset NCBI. I have a runtime error with the size of tensors.

Here is the code for class WeightedLossTrainer

class WeightedLossTrainer(Trainer):
     def compute_loss(self,model,input,return_output=False):
          labels=input.get("ner_tags")
          print(labels)
          outputs = model(**input)
          logits = outputs.get("logits")
          loss_function = nn.CrossEntropyLoss(torch.FloatTensor([New_weight_sum0,New_weight_sum1,New_weight_sum2]))
          if self.args.past_index >= 0:
               self._past = outputs[self.args.past_index]
          loss = loss_function(logits, labels)
          loss_function.backward()
          #loss = loss_function(logits.view(-1, model.num_labels), labels.view(-1))
          return (loss(logits, labels), outputs) if return_output else loss(logits, labels)
          outputs=self.model(**input)
          logits= outputs.logits
          self.labels=input.get("nertags")
          self.loss_func=nn.CrossEntropyLoss(torch.FloatTensor([New_weight_sum0,New_weight_sum1,New_weight_sum2]))
          if self.loss_func is not None:
               loss=self.loss_func(logits,self.labels)
          if logits is not None and self.labels is not None:
              loss=self.loss_func(logits,self.labels)
          if logits is not None and self.labels is not None:
              loss= self.loss_func(logits.view(-1, model.num_labels), self.labels.view(-1))
              logits=outputs.get("logits")
          if self.args.past_index >= 0:
               self._past =self.outputs[self.args.past_index]
          if self.labels is not None:
                 loss = self.loss_func(logits, self.labels)
          else:
               self.loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
               return(loss,outputs) if return_output else loss

Here is the link to the whole code

Any help to resolve this issue?

Best,
Ghadeer