Generating text from a pretrained BERT-based decoder

I have a decoder built on pre-trained BERT. Training looks fine, but I am having problems at inference: I wonder whether my greedy-generation approach in the function below is correct.

import torch
import torch.nn as nn
from transformers import BertConfig, BertLMHeadModel
from transformers.models.bart.modeling_bart import shift_tokens_right

class TextDecoder(nn.Module):
  def __init__(self, tokenizer, config, device):
    super().__init__()
    self.device = device
    self.config = config
    self.config["add_cross_attention"] = True  # add cross-attention over the encoder states
    self.config["is_decoder"] = True           # use a causal (left-to-right) attention mask
    self.tokenizer = tokenizer
    self.decoder_config = BertConfig.from_dict(self.config)
    self.decoder = BertLMHeadModel.from_pretrained('bert-base-uncased', config=self.decoder_config).to(self.device)
    self.decoder.config.decoder_start_token_id = tokenizer.cls_token_id
    self.decoder.config.pad_token_id = tokenizer.pad_token_id

  def forward(self, input_ids, attention_mask, enc_out, encoder_attention_mask):
    input_ids = input_ids.to(self.device)
    attention_mask = attention_mask.to(self.device)
    encoder_attention_mask = encoder_attention_mask.to(self.device)
    # prepend CLS as the decoder start token and drop the last token
    decoder_input_ids = shift_tokens_right(input_ids, self.tokenizer.pad_token_id, self.tokenizer.cls_token_id)
    # decoder_attention_mask = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)

    decoder_out = self.decoder(input_ids=decoder_input_ids, attention_mask=attention_mask,
                               encoder_hidden_states=enc_out, encoder_attention_mask=encoder_attention_mask,
                               labels=decoder_input_ids, return_dict=True)
    return decoder_out.loss, decoder_out.logits
  
  def generate(self, input_ids, encoder_hidden_states, do_sample=False, max_length=14):
    # greedy decoding when do_sample=False; extra kwargs are forwarded to the model
    outputs = self.decoder.generate(input_ids=input_ids,
                                    encoder_hidden_states=encoder_hidden_states,
                                    do_sample=do_sample,
                                    max_length=max_length)
    # decode the generated token ids back to strings
    output_sen = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return output_sen

def train(data_loader):
  text_encoder.train()
  text_decoder.train()
  epoch_loss = 0
  for i, (in_ids, in_msk, out_ids, out_msk) in enumerate(data_loader):
    optimizer.zero_grad()
    encoder_out = text_encoder(in_ids, in_msk)
    loss_, logits = text_decoder(out_ids, out_msk, encoder_out, encoder_attention_mask=in_msk)
    loss_.backward()
    optimizer.step()
    epoch_loss += loss_.item()
  return epoch_loss / len(data_loader)
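
Since train() calls backward() and optimizer.step(), running it on the validation set would also update the weights, so for the validation loss I use a no-grad counterpart (a minimal sketch mirroring train() above):

def evaluate(data_loader):
  text_encoder.eval()
  text_decoder.eval()
  epoch_loss = 0
  with torch.no_grad():
    for i, (in_ids, in_msk, out_ids, out_msk) in enumerate(data_loader):
      encoder_out = text_encoder(in_ids, in_msk)
      loss_, logits = text_decoder(out_ids, out_msk, encoder_out, encoder_attention_mask=in_msk)
      epoch_loss += loss_.item()
  return epoch_loss / len(data_loader)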

history = {'train_loss': [], 'val_loss': []}

for epoch in range(20):
  train_loss = train(train_dataloader)
  val_loss = evaluate(val_dataloader)
  history['train_loss'].append(train_loss)
  history['val_loss'].append(val_loss)
  print(f'Epoch:{epoch} Train Loss: {train_loss} Val Loss: {val_loss}')

My loss function: [image]
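
(For reference, the loss returned by BertLMHeadModel is the next-token cross-entropy; the model shifts the labels internally, so it is roughly \(\mathcal{L} = -\frac{1}{T-1}\sum_{t=1}^{T-1} \log p(y_{t+1} \mid y_{\le t}, \mathrm{enc\_out})\), averaged over positions.)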

This is my generation function at inference time:

def generate_greedy(in_ten, in_mask):
  '''Two possible approaches: take the logits, turn them into a distribution, then (1) take the argmax or (2) sample from the distribution.'''
  text_encoder.eval()
  text_decoder.eval()
  text_src = in_ten.to(device)
  text_mask = in_mask.to(device)
  text_feats = text_encoder(in_seq=text_src, in_mask=text_mask)
  # start every sequence with CLS, BERT's decoder start token (id 101)
  input_ids = torch.full((text_feats.shape[0], 1), tokenizer.cls_token_id, dtype=torch.long, device=device)
  output_sen = text_decoder.generate(input_ids, text_feats)
  input_sen = tokenizer.batch_decode(in_ten, skip_special_tokens=True)
  return input_sen, output_sen
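
For approach (1), a manual greedy loop that I would expect to be equivalent (a sketch that calls the underlying BertLMHeadModel directly, reusing tokenizer, device, and text_decoder from above):

def greedy_loop(enc_out, max_length=14):
  # start from CLS and repeatedly append the argmax of the last position's logits
  ids = torch.full((enc_out.shape[0], 1), tokenizer.cls_token_id, dtype=torch.long, device=device)
  with torch.no_grad():
    for _ in range(max_length - 1):
      logits = text_decoder.decoder(input_ids=ids, encoder_hidden_states=enc_out).logits
      next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1) argmax
      # (2) sampling would be: torch.multinomial(logits[:, -1, :].softmax(-1), 1)
      ids = torch.cat([ids, next_id], dim=1)
  return tokenizer.batch_decode(ids, skip_special_tokens=True)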

But the output is:

Input sentence:  a group of men are loading cotton onto a truck
Output sentence:  a man in a suit is running past two other gentleman also also
******************************
Input sentence:  a man sleeping in a green room on a couch
Output sentence:  a man in a suit is running past two other gentleman also also
******************************
Input sentence:  a boy wearing headphones sits on a womans shoulders
Output sentence:  a man in a suit is running past two other gentleman also also
******************************
Input sentence:  two men setting up a blue ice fishing hut on an iced
Output sentence:  a man in a suit is running past two other gentleman also also
******************************
Input sentence:  a balding man wearing a red life jacket is sitting in
Output sentence:  a man in a suit is running past two other gentleman also also

Hi @sgugger @patrickvonplaten, sorry for tagging you here, but I am really stuck and would be grateful for any help.

Thanks