Fine-tuning Pegasus for summarization by splitting the encoder and decoder

Hello, I am fine-tuning Pegasus on a summarization task and want to integrate a domain adaptation script into the training, which requires me to separate out the encoder and decoder objects of the PegasusForConditionalGeneration class. To this end, this is the code I have so far:

   import os

   import torch
   import torch.nn as nn
   from torch.utils.data import DataLoader
   from tqdm import tqdm

   # keep direct handles on the submodules of the wrapped PegasusForConditionalGeneration
   model = SummarizerNet(model_name, args)
   encoder = model.encoder
   decoder = model.decoder
   lm_head = model.lm_head
   vocab_size = model.vocab_size
   
   
   if torch.cuda.device_count()>1:
       model = nn.DataParallel(model)
   
   model.to(args.device)

   src_loader = DataLoader(src_dataset, batch_size=args.batch_size,
                               shuffle=True, num_workers=1, pin_memory=True, drop_last=True)
   n_batches = len(src_loader)
   
   optim = torch.optim.AdamW(list(model.parameters()), weight_decay=0.01)
   summarizer_criterion = nn.CrossEntropyLoss()
   
   for epoch in range(args.num_epochs):
       model.train()
       total_summarization_loss = 0
                       
       for src_batch in tqdm(src_loader, leave=False, total=n_batches):
                       
           # unroll all items from batch dictionary
           src_input_ids, src_attention_mask, src_decoder_input_ids, src_decoder_attention_mask, src_labels = (v.to(args.device) for _,v in src_batch.items())

            # build decoder_input_ids by shifting the labels one token to the right
            config = model.module.config if torch.cuda.device_count() > 1 else model.config
            src_decoder_input_ids = shift_tokens_right(src_labels, config.pad_token_id, config.decoder_start_token_id)
           
           # pass the source domain documents through the encoder
           encoder_features_src = encoder(input_ids=src_input_ids, attention_mask=src_attention_mask).last_hidden_state
           
           # pass the encoder hidden state and the source reference summaries to the decoder
           decoder_outputs = decoder(input_ids=src_decoder_input_ids, attention_mask = src_decoder_attention_mask,
                                     encoder_hidden_states = encoder_features_src, encoder_attention_mask = src_attention_mask).last_hidden_state
           
           # create the final logits bias (copied from src/transformers/modeling_pegasus.py)
            pegasus = model.module.model if torch.cuda.device_count() > 1 else model.model
            final_logits_bias = pegasus.final_logits_bias
           
           lm_head_output = lm_head(decoder_outputs) + final_logits_bias
           
           loss = summarizer_criterion(lm_head_output.reshape(-1, vocab_size), src_labels.reshape(-1))
           
           optim.zero_grad()
           loss.backward()
           optim.step()
           
           total_summarization_loss += loss.item()
       
       mean_summarization_loss = total_summarization_loss / n_batches
        tqdm.write(f'EPOCH {epoch+1:03d}: '
                   f'summarization_loss={mean_summarization_loss:.4f}')

        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)
        pegasus_to_save = model.module.model if torch.cuda.device_count() > 1 else model.model
        model_checkpoint = {'model_state_dict': pegasus_to_save.state_dict(), 'optimizer': optim.state_dict()}
       torch.save(model_checkpoint, save_path)
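
For reference, the un-split baseline I am comparing against simply routes the same batch through the full model and lets it build the decoder inputs and compute the loss internally. Roughly (a minimal sketch, reusing the batch tensors from the loop above):

   # baseline sketch: PegasusForConditionalGeneration shifts the labels into
   # decoder_input_ids itself and uses CrossEntropyLoss with ignore_index=-100
   full_model = PegasusForConditionalGeneration.from_pretrained(model_name).to(args.device)
   outputs = full_model(input_ids=src_input_ids,
                        attention_mask=src_attention_mask,
                        labels=src_labels)
   baseline_loss = outputs.loss
   baseline_logits = outputs.logits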

The helper classes and functions I have used are as follows:

import torch
import torch.nn as nn
from transformers import AutoTokenizer, PegasusForConditionalGeneration


class SummarizerNet(nn.Module):
    def __init__(self, model_name, args):
        super().__init__()
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.config = self.model.config
        self.encoder = self.model.get_encoder()
        self.decoder = self.model.get_decoder()
        self.lm_head = self.model.lm_head
        self.vocab_size = self.model.config.vocab_size
        self.model.post_init()

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
        # mirror the manual encoder -> decoder -> lm_head pipeline used in the training loop
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dec = self.decoder(input_ids=decoder_input_ids, attention_mask=decoder_attention_mask,
                           encoder_hidden_states=enc, encoder_attention_mask=attention_mask).last_hidden_state
        return self.lm_head(dec) + self.model.final_logits_bias

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
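
As a quick sanity check on shift_tokens_right, here is what it does to a toy label row (illustrative token ids only, with 0 standing in for both pad_token_id and decoder_start_token_id):

   labels = torch.tensor([[5, 6, -100, -100]])   # -100 marks ignored/padded label positions
   shift_tokens_right(labels, 0, 0)
   # -> tensor([[0, 5, 6, 0]]): the start token is prepended, the last label is dropped,
   #    and the -100 that got copied over is replaced by pad_token_id
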
def prepare_data(model_name,
                 train_texts, train_labels,
                 val_texts=None, val_labels=None,
                 test_texts=None, test_labels=None):
    """
    Prepare input data for model fine-tuning
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    prepare_val = val_texts is not None and val_labels is not None
    prepare_test = test_texts is not None and test_labels is not None
    
    train_dataset = tokenize_data(train_texts, train_labels, tokenizer)
    val_dataset = tokenize_data(val_texts, val_labels, tokenizer) if prepare_val else None
    test_dataset = tokenize_data(test_texts, test_labels, tokenizer) if prepare_test else None

    return train_dataset, val_dataset, test_dataset, tokenizer

def tokenize_data(texts, labels, tokenizer):
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
    decodings = tokenizer(labels, truncation=True, padding='max_length', max_length=256)
    dataset_tokenized = PegasusDatasetDA(encodings, decodings, tokenizer)
    return dataset_tokenized
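
For completeness, src_dataset in the training code above comes out of these helpers along these lines (train_texts and train_labels are hypothetical names for my raw source-domain documents and reference summaries):

   # assumption: the source-domain loader is built from the train split
   train_dataset, val_dataset, test_dataset, tokenizer = prepare_data(
       model_name, train_texts, train_labels)   # val/test stay None when not supplied
   src_dataset = train_dataset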

The issue I am facing is that my model does not generate summaries as coherent as those from the un-split model (i.e., passing the same inputs through the PegasusForConditionalGeneration class itself). Where am I messing up?