How to train a custom EncoderDecoderModel with KLD loss

I am trying to create a custom EncoderDecoderModel with BERT as both the encoder and the decoder (I followed this article for reference: link_here).
Here is what I have so far:

import torch
from transformers import EncoderDecoderModel
from transformers.modeling_outputs import Seq2SeqLMOutput

class bertGenModelKLD(EncoderDecoderModel):
    def __init__(self, config):
        super(bertGenModelKLD, self).__init__(config)
        print('-'*50, 'KLD', '-'*50)
        # wrap a pretrained BERT-to-BERT model and copy the generation settings from config
        self.bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
        self.bert2bert.config.decoder_start_token_id = config.decoder_start_token_id
        self.bert2bert.config.eos_token_id = config.eos_token_id
        self.bert2bert.config.pad_token_id = config.pad_token_id
        # self.bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
        # note: the logits come from the decoder, which is resized to decLength below
        self.bert2bert.config.vocab_size = config.encLength
        self.bert2bert.config.max_length = config.max_length
        self.bert2bert.config.no_repeat_ngram_size = config.no_repeat_ngram_size
        self.bert2bert.config.early_stopping = config.early_stopping
        self.bert2bert.config.length_penalty = config.length_penalty
        self.bert2bert.config.num_beams = config.num_beams
        # resize the embedding matrices to the (extended) tokenizer vocabularies
        self.bert2bert.encoder.resize_token_embeddings(config.encLength)
        self.bert2bert.decoder.resize_token_embeddings(config.decLength)
        print(self.bert2bert.config.vocab_size, '<<<vocab_size')
        self.encoder = self.bert2bert.encoder
        self.decoder = self.bert2bert.decoder
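One detail in the __init__ above that matters for the loss: the logits are produced by the decoder, so their last dimension is the decoder's resized vocabulary (decLength), while `bert2bert.config.vocab_size` is set to encLength. If the two lengths differ, `logits.view(-1, self.bert2bert.config.vocab_size)` no longer lines up with the logits. A small sketch to check this, assuming toy BERT configs built locally (no pretrained weights are downloaded) and hypothetical sizes standing in for encLength and decLength:

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel


def tiny_bert_config(vocab_size=30):
    # Toy sizes so the sketch runs quickly on CPU.
    return BertConfig(vocab_size=vocab_size, hidden_size=16, num_hidden_layers=1,
                      num_attention_heads=2, intermediate_size=32)


config = EncoderDecoderConfig.from_encoder_decoder_configs(tiny_bert_config(), tiny_bert_config())
config.decoder_start_token_id = 1
config.pad_token_id = 0
model = EncoderDecoderModel(config=config)

# Hypothetical stand-ins for config.encLength / config.decLength in the question.
enc_length, dec_length = 33, 35
model.encoder.resize_token_embeddings(enc_length)
model.decoder.resize_token_embeddings(dec_length)

input_ids = torch.randint(2, enc_length, (2, 6))
decoder_input_ids = torch.randint(2, dec_length, (2, 4))
with torch.no_grad():
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

# The last dimension follows the decoder's resized vocabulary, not the encoder's.
print(tuple(logits.shape), model.decoder.config.vocab_size)
```

Reading the vocabulary size from `logits.size(-1)` (or from the decoder's config) avoids depending on a separately maintained config value.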

Now I want to use KL divergence (KLD) as the loss function instead of the standard NLL. To do that, I added the following forward method to the bertGenModelKLD class above:

    def forward(self, input_ids, attention_mask, labels):
        outputs = self.bert2bert(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=labels, labels=labels, return_dict=True)
        # outputs.loss is the built-in NLL (cross-entropy) loss
        loss, logits = outputs.loss, outputs.logits
        loss_ = None
        if labels is not None:
            loss_funct = torch.nn.KLDivLoss(reduction="batchmean")
            loss_ = loss_funct(logits.view(-1, self.bert2bert.config.vocab_size), labels.view(-1))
        # Seq2SeqLMOutput (not Seq2SeqModelOutput) is the output class that has a logits field
        return Seq2SeqLMOutput(
            loss=loss_ if loss_ is not None else loss,
            logits=logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
        )
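`torch.nn.KLDivLoss` expects its input to be log-probabilities and its target to be a probability distribution of the same shape. In the forward above, `logits.view(-1, vocab_size)` has shape (N, V) while `labels.view(-1)` has shape (N,), so PyTorch tries to broadcast them; that is very likely where the 600 vs 30525 mismatch comes from (N = 600 label tokens, V = 30525). Note also that with a one-hot target, KLD is exactly the usual NLL (a one-hot distribution has zero entropy), so KLD only changes training when the target is soft. A minimal sketch of a shape-correct token-level KLD, assuming label smoothing as the soft target and padding marked with -100 in `labels`; the helper name `kld_token_loss` and the smoothing value are illustrative:

```python
import torch
import torch.nn.functional as F


def kld_token_loss(logits, labels, smoothing=0.1, ignore_index=-100):
    """Token-level KL(target || model) with a label-smoothed target.

    logits: (batch, seq_len, vocab) raw decoder scores
    labels: (batch, seq_len) gold token ids; ignore_index marks padding
    Returns the KL divergence averaged over non-ignored tokens.
    """
    vocab_size = logits.size(-1)  # take V from the logits, not from a config value
    log_probs = F.log_softmax(logits, dim=-1).reshape(-1, vocab_size)  # (N, V)
    flat_labels = labels.reshape(-1)                                   # (N,)

    mask = flat_labels != ignore_index
    safe_labels = flat_labels.masked_fill(~mask, 0)  # placeholder id; these rows are masked out below

    # Target distribution with the same (N, V) shape as log_probs:
    # 1 - smoothing on the gold token, smoothing spread over the other V - 1 tokens.
    target = torch.full_like(log_probs, smoothing / (vocab_size - 1))
    target.scatter_(1, safe_labels.unsqueeze(1), 1.0 - smoothing)

    # reduction="none" keeps (N, V) so padded positions can be excluded.
    per_token = F.kl_div(log_probs, target, reduction="none").sum(dim=-1)  # (N,)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)
```

In the forward above, this would take the place of the KLDivLoss call, e.g. `loss_ = kld_token_loss(logits, labels)`. With `smoothing=0.0` it reproduces ordinary cross-entropy, which makes a convenient sanity check.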

My questions are:

  1. Am I doing this correctly? I am getting the following error:

“The size of tensor a (600) must match the size of tensor b (30525) at non-singleton dimension 1”

  2. How can I extend the EncoderDecoderModel class to write a custom model with a custom loss, as I tried above?
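For the second question, a sketch of subclassing EncoderDecoderModel directly (instead of wrapping a second EncoderDecoderModel inside it) and overriding forward so that it returns a KLD loss. It assumes a recent transformers version in which `EncoderDecoderModel.forward` accepts `labels`, shifts them right to build `decoder_input_ids` when none are given, and returns a `Seq2SeqLMOutput`. The class name, the `kld_smoothing` attribute and the label-smoothed target are illustrative choices, not part of the library:

```python
import torch
import torch.nn.functional as F
from transformers import EncoderDecoderModel
from transformers.modeling_outputs import Seq2SeqLMOutput


class BertGenModelKLD(EncoderDecoderModel):
    """Hypothetical EncoderDecoderModel subclass trained with a label-smoothed KLD loss."""

    kld_smoothing = 0.1  # illustrative value

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # The parent builds decoder_input_ids from labels (shifted right) and runs
        # encoder + decoder; its own NLL loss is replaced below.
        kwargs["return_dict"] = True  # outputs are read by attribute name below
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask,
                                  labels=labels, **kwargs)
        loss = None
        if labels is not None:
            logits = outputs.logits
            vocab_size = logits.size(-1)
            log_probs = F.log_softmax(logits, dim=-1).reshape(-1, vocab_size)  # (N, V)
            flat = labels.reshape(-1)                                          # (N,)
            mask = flat != -100  # assumes padding positions in labels are -100
            # Soft target of shape (N, V): 1 - smoothing on the gold token.
            target = torch.full_like(log_probs, self.kld_smoothing / (vocab_size - 1))
            target.scatter_(1, flat.masked_fill(~mask, 0).unsqueeze(1), 1.0 - self.kld_smoothing)
            per_token = F.kl_div(log_probs, target, reduction="none").sum(dim=-1)
            loss = (per_token * mask).sum() / mask.sum().clamp(min=1)
        return Seq2SeqLMOutput(
            loss=loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
```

With pretrained weights, the inherited classmethod should return an instance of the subclass, e.g. `model = BertGenModelKLD.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")`, after which `decoder_start_token_id`, `pad_token_id` and the generation settings are set on `model.config` as in the question (the parent needs the first two to shift the labels). The parent still computes its own NLL along the way; calling `super().forward` without `labels` and building `decoder_input_ids` yourself would avoid that extra work.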

Any help will be greatly appreciated.