Inheriting from BartForConditionalGeneration into a new class - weight not initializing

Failing to load and initiate the pre-trained BART model when inheriting it from the BartForConditionalGeneration

from transformers.models.bart.modeling_bart import BartForConditionalGeneration,BartPretrainedModel,BartConfig
import torch.nn as nn

import torch

class BartExp(BartPretrainedModel):
    """BART wrapper meant for adding extra heads / losses on top of
    BartForConditionalGeneration.

    NOTE(review): nesting the whole model under ``self.bart`` changes every
    parameter name (``bart.model.encoder...`` instead of ``model.encoder...``),
    so ``BartExp.from_pretrained('facebook/bart-base')`` cannot match the
    checkpoint keys and re-initializes all weights — that is exactly what the
    long warning below reports.
    """

    def __init__(self, config: BartConfig):
        super().__init__(config=config)
        # Full seq2seq model with LM head; its weights stay randomly
        # initialized until a checkpoint whose keys match this nested
        # structure is loaded.
        self.bart = BartForConditionalGeneration(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Delegate to the wrapped BART model and return its output.

        BUG FIX: the original body called ``self.model``, an attribute that is
        never defined on this class (the submodule is ``self.bart``), which
        raised AttributeError at the first forward pass. It also silently
        dropped most of the accepted arguments; they are now all forwarded
        (passing ``None`` matches the wrapped model's own defaults).
        """
        return self.bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# NOTE(review): this call triggers the "weights not used / newly initialized"
# warning quoted below — the checkpoint's ``model.*`` keys do not line up with
# BartExp's nested ``bart.*`` parameter names, so nothing is actually loaded.
model = BartExp.from_pretrained('facebook/bart-base')

And then I get this endless warning:

    Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartEXP: ['model.shared.weight', 'model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.encoder.layers.0.self_attn.k_proj.weight', 'model.encoder.layers.0.self_attn.k_proj.bias', 'model.encoder.layers.0.self_attn.v_proj.weight', 'model.encoder.layers.0.self_attn.v_proj.bias', 'model.encoder.layers.0.self_attn.q_proj.weight', 'model.encoder.layers.0.self_attn.q_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.weight', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn_layer_norm.weight', 'model.encoder.layers.0.self_attn_layer_norm.bias', 'model.encoder.layers.0.fc1.weight', 'model.encoder.layers.0.fc1.bias', 'model.encoder.layers.0.fc2.weight', 'model.encoder.layers.0.fc2.bias', 'model.encoder.layers.0.final_layer_norm.weight', 'model.encoder.layers.0.final_layer_norm.bias', 'model.encoder.layers.1.self_attn.k_proj.weight', 'model.encoder.layers.1.self_attn.k_proj.bias', 'model.encoder.layers.1.self_attn.v_proj.weight', 'model.encoder.layers.1.self_attn.v_proj.bias', 'model.encoder.layers.1.self_attn.q_proj.weight', 'model.encoder.layers.1.self_attn.q_proj.bias', 'model.encoder.layers.1.self_attn.out_proj.weight', 'model.encoder.layers.1.self_attn.out_proj.bias', 'model.encoder.layers.1.self_attn_layer_norm.weight', 'model.encoder.layers.1.self_attn_layer_norm.bias', 'model.encoder.layers.1.fc1.weight', 'model.encoder.layers.1.fc1.bias', 'model.encoder.layers.1.fc2.weight', 'model.encoder.layers.1.fc2.bias', 'model.encoder.layers.1.final_layer_norm.weight', 'model.encoder.layers.1.final_layer_norm.bias', 'model.encoder.layers.2.self_attn.k_proj.weight', 'model.encoder.layers.2.self_attn.k_proj.bias', 'model.encoder.layers.2.self_attn.v_proj.weight', 'model.encoder.layers.2.self_attn.v_proj.bias', 'model.encoder.layers.2.self_attn.q_proj.weight', 'model.encoder.layers.2.self_attn.q_proj.bias', 
'model.encoder.layers.2.self_attn.out_proj.weight', 'model.encoder.layers.2.self_attn.out_proj.bias', 'model.encoder.layers.2.self_attn_layer_norm.weight', 'model.encoder.layers.2.self_attn_layer_norm.bias', 'model.encoder.layers.2.fc1.weight', 'model.encoder.layers.2.fc1.bias', 'model.encoder.layers.2.fc2.weight', 'model.encoder.layers.2.fc2.bias', 'model.encoder.layers.2.final_layer_norm.weight', 'model.encoder.layers.2.final_layer_norm.bias', 'model.encoder.layers.3.self_attn.k_proj.weight', 'model.encoder.layers.3.self_attn.k_proj.bias', 'model.encoder.layers.3.self_attn.v_proj.weight', 'model.encoder.layers.3.self_attn.v_proj.bias', 'model.encoder.layers.3.self_attn.q_proj.weight', 'model.encoder.layers.3.self_attn.q_proj.bias', 'model.encoder.layers.3.self_attn.out_proj.weight', 'model.encoder.layers.3.self_attn.out_proj.bias', 'model.encoder.layers.3.self_attn_layer_norm.weight', 'model.encoder.layers.3.self_attn_layer_norm.bias', 'model.encoder.layers.3.fc1.weight', 'model.encoder.layers.3.fc1.bias', 'model.encoder.layers.3.fc2.weight', 'model.encoder.layers.3.fc2.bias', 'model.encoder.layers.3.final_layer_norm.weight', 'model.encoder.layers.3.final_layer_norm.bias', 'model.encoder.layers.4.self_attn.k_proj.weight', 'model.encoder.layers.4.self_attn.k_proj.bias', 'model.encoder.layers.4.self_attn.v_proj.weight', 'model.encoder.layers.4.self_attn.v_proj.bias', 'model.encoder.layers.4.self_attn.q_proj.weight', 'model.encoder.layers.4.self_attn.q_proj.bias', 'model.encoder.layers.4.self_attn.out_proj.weight', 'model.encoder.layers.4.self_attn.out_proj.bias', 'model.encoder.layers.4.self_attn_layer_norm.weight', 'model.encoder.layers.4.self_attn_layer_norm.bias', 'model.encoder.layers.4.fc1.weight', 'model.encoder.layers.4.fc1.bias', 'model.encoder.layers.4.fc2.weight', 'model.encoder.layers.4.fc2.bias', 'model.encoder.layers.4.final_layer_norm.weight', 'model.encoder.layers.4.final_layer_norm.bias', 'model.encoder.layers.5.self_attn.k_proj.weight', 
'model.encoder.layers.5.self_attn.k_proj.bias', 'model.encoder.layers.5.self_attn.v_proj.weight', 'model.encoder.layers.5.self_attn.v_proj.bias', 'model.encoder.layers.5.self_attn.q_proj.weight', 'model.encoder.layers.5.self_attn.q_proj.bias', 'model.encoder.layers.5.self_attn.out_proj.weight', 'model.encoder.layers.5.self_attn.out_proj.bias', 'model.encoder.layers.5.self_attn_layer_norm.weight', 'model.encoder.layers.5.self_attn_layer_norm.bias', 'model.encoder.layers.5.fc1.weight', 'model.encoder.layers.5.fc1.bias', 'model.encoder.layers.5.fc2.weight', 'model.encoder.layers.5.fc2.bias', 'model.encoder.layers.5.final_layer_norm.weight', 'model.encoder.layers.5.final_layer_norm.bias', 'model.encoder.layernorm_embedding.weight', 'model.encoder.layernorm_embedding.bias', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'model.decoder.layers.0.self_attn.k_proj.weight', 'model.decoder.layers.0.self_attn.k_proj.bias', 'model.decoder.layers.0.self_attn.v_proj.weight', 'model.decoder.layers.0.self_attn.v_proj.bias', 'model.decoder.layers.0.self_attn.q_proj.weight', 'model.decoder.layers.0.self_attn.q_proj.bias', 'model.decoder.layers.0.self_attn.out_proj.weight', 'model.decoder.layers.0.self_attn.out_proj.bias', 'model.decoder.layers.0.self_attn_layer_norm.weight', 'model.decoder.layers.0.self_attn_layer_norm.bias', 'model.decoder.layers.0.encoder_attn.k_proj.weight', 'model.decoder.layers.0.encoder_attn.k_proj.bias', 'model.decoder.layers.0.encoder_attn.v_proj.weight', 'model.decoder.layers.0.encoder_attn.v_proj.bias', 'model.decoder.layers.0.encoder_attn.q_proj.weight', 'model.decoder.layers.0.encoder_attn.q_proj.bias', 'model.decoder.layers.0.encoder_attn.out_proj.weight', 'model.decoder.layers.0.encoder_attn.out_proj.bias', 'model.decoder.layers.0.encoder_attn_layer_norm.weight', 'model.decoder.layers.0.encoder_attn_layer_norm.bias', 'model.decoder.layers.0.fc1.weight', 'model.decoder.layers.0.fc1.bias', 'model.decoder.layers.0.fc2.weight', 
'model.decoder.layers.0.fc2.bias', 'model.decoder.layers.0.final_layer_norm.weight', 'model.decoder.layers.0.final_layer_norm.bias', 'model.decoder.layers.1.self_attn.k_proj.weight', 'model.decoder.layers.1.self_attn.k_proj.bias', 'model.decoder.layers.1.self_attn.v_proj.weight', 'model.decoder.layers.1.self_attn.v_proj.bias', 'model.decoder.layers.1.self_attn.q_proj.weight', 'model.decoder.layers.1.self_attn.q_proj.bias', 'model.decoder.layers.1.self_attn.out_proj.weight', 'model.decoder.layers.1.self_attn.out_proj.bias', 'model.decoder.layers.1.self_attn_layer_norm.weight', 'model.decoder.layers.1.self_attn_layer_norm.bias', 'model.decoder.layers.1.encoder_attn.k_proj.weight', 'model.decoder.layers.1.encoder_attn.k_proj.bias', 'model.decoder.layers.1.encoder_attn.v_proj.weight', 'model.decoder.layers.1.encoder_attn.v_proj.bias', 'model.decoder.layers.1.encoder_attn.q_proj.weight', 'model.decoder.layers.1.encoder_attn.q_proj.bias', 'model.decoder.layers.1.encoder_attn.out_proj.weight', 'model.decoder.layers.1.encoder_attn.out_proj.bias', 'model.decoder.layers.1.encoder_attn_layer_norm.weight', 'model.decoder.layers.1.encoder_attn_layer_norm.bias', 'model.decoder.layers.1.fc1.weight', 'model.decoder.layers.1.fc1.bias', 'model.decoder.layers.1.fc2.weight', 'model.decoder.layers.1.fc2.bias', 'model.decoder.layers.1.final_layer_norm.weight', 'model.decoder.layers.1.final_layer_norm.bias', 'model.decoder.layers.2.self_attn.k_proj.weight', 'model.decoder.layers.2.self_attn.k_proj.bias', 'model.decoder.layers.2.self_attn.v_proj.weight', 'model.decoder.layers.2.self_attn.v_proj.bias', 'model.decoder.layers.2.self_attn.q_proj.weight', 'model.decoder.layers.2.self_attn.q_proj.bias', 'model.decoder.layers.2.self_attn.out_proj.weight', 'model.decoder.layers.2.self_attn.out_proj.bias', 'model.decoder.layers.2.self_attn_layer_norm.weight', 'model.decoder.layers.2.self_attn_layer_norm.bias', 'model.decoder.layers.2.encoder_attn.k_proj.weight', 
'model.decoder.layers.2.encoder_attn.k_proj.bias', 'model.decoder.layers.2.encoder_attn.v_proj.weight', 'model.decoder.layers.2.encoder_attn.v_proj.bias', 'model.decoder.layers.2.encoder_attn.q_proj.weight', 'model.decoder.layers.2.encoder_attn.q_proj.bias', 'model.decoder.layers.2.encoder_attn.out_proj.weight', 'model.decoder.layers.2.encoder_attn.out_proj.bias', 'model.decoder.layers.2.encoder_attn_layer_norm.weight', 'model.decoder.layers.2.encoder_attn_layer_norm.bias', 'model.decoder.layers.2.fc1.weight', 'model.decoder.layers.2.fc1.bias', 'model.decoder.layers.2.fc2.weight', 'model.decoder.layers.2.fc2.bias', 'model.decoder.layers.2.final_layer_norm.weight', 'model.decoder.layers.2.final_layer_norm.bias', 'model.decoder.layers.3.self_attn.k_proj.weight', 'model.decoder.layers.3.self_attn.k_proj.bias', 'model.decoder.layers.3.self_attn.v_proj.weight', 'model.decoder.layers.3.self_attn.v_proj.bias', 'model.decoder.layers.3.self_attn.q_proj.weight', 'model.decoder.layers.3.self_attn.q_proj.bias', 'model.decoder.layers.3.self_attn.out_proj.weight', 'model.decoder.layers.3.self_attn.out_proj.bias', 'model.decoder.layers.3.self_attn_layer_norm.weight', 'model.decoder.layers.3.self_attn_layer_norm.bias', 'model.decoder.layers.3.encoder_attn.k_proj.weight', 'model.decoder.layers.3.encoder_attn.k_proj.bias', 'model.decoder.layers.3.encoder_attn.v_proj.weight', 'model.decoder.layers.3.encoder_attn.v_proj.bias', 'model.decoder.layers.3.encoder_attn.q_proj.weight', 'model.decoder.layers.3.encoder_attn.q_proj.bias', 'model.decoder.layers.3.encoder_attn.out_proj.weight', 'model.decoder.layers.3.encoder_attn.out_proj.bias', 'model.decoder.layers.3.encoder_attn_layer_norm.weight', 'model.decoder.layers.3.encoder_attn_layer_norm.bias', 'model.decoder.layers.3.fc1.weight', 'model.decoder.layers.3.fc1.bias', 'model.decoder.layers.3.fc2.weight', 'model.decoder.layers.3.fc2.bias', 'model.decoder.layers.3.final_layer_norm.weight', 'model.decoder.layers.3.final_layer_norm.bias', 
'model.decoder.layers.4.self_attn.k_proj.weight', 'model.decoder.layers.4.self_attn.k_proj.bias', 'model.decoder.layers.4.self_attn.v_proj.weight', 'model.decoder.layers.4.self_attn.v_proj.bias', 'model.decoder.layers.4.self_attn.q_proj.weight', 'model.decoder.layers.4.self_attn.q_proj.bias', 'model.decoder.layers.4.self_attn.out_proj.weight', 'model.decoder.layers.4.self_attn.out_proj.bias', 'model.decoder.layers.4.self_attn_layer_norm.weight', 'model.decoder.layers.4.self_attn_layer_norm.bias', 'model.decoder.layers.4.encoder_attn.k_proj.weight', 'model.decoder.layers.4.encoder_attn.k_proj.bias', 'model.decoder.layers.4.encoder_attn.v_proj.weight', 'model.decoder.layers.4.encoder_attn.v_proj.bias', 'model.decoder.layers.4.encoder_attn.q_proj.weight', 'model.decoder.layers.4.encoder_attn.q_proj.bias', 'model.decoder.layers.4.encoder_attn.out_proj.weight', 'model.decoder.layers.4.encoder_attn.out_proj.bias', 'model.decoder.layers.4.encoder_attn_layer_norm.weight', 'model.decoder.layers.4.encoder_attn_layer_norm.bias', 'model.decoder.layers.4.fc1.weight', 'model.decoder.layers.4.fc1.bias', 'model.decoder.layers.4.fc2.weight', 'model.decoder.layers.4.fc2.bias', 'model.decoder.layers.4.final_layer_norm.weight', 'model.decoder.layers.4.final_layer_norm.bias', 'model.decoder.layers.5.self_attn.k_proj.weight', 'model.decoder.layers.5.self_attn.k_proj.bias', 'model.decoder.layers.5.self_attn.v_proj.weight', 'model.decoder.layers.5.self_attn.v_proj.bias', 'model.decoder.layers.5.self_attn.q_proj.weight', 'model.decoder.layers.5.self_attn.q_proj.bias', 'model.decoder.layers.5.self_attn.out_proj.weight', 'model.decoder.layers.5.self_attn.out_proj.bias', 'model.decoder.layers.5.self_attn_layer_norm.weight', 'model.decoder.layers.5.self_attn_layer_norm.bias', 'model.decoder.layers.5.encoder_attn.k_proj.weight', 'model.decoder.layers.5.encoder_attn.k_proj.bias', 'model.decoder.layers.5.encoder_attn.v_proj.weight', 'model.decoder.layers.5.encoder_attn.v_proj.bias', 
'model.decoder.layers.5.encoder_attn.q_proj.weight', 'model.decoder.layers.5.encoder_attn.q_proj.bias', 'model.decoder.layers.5.encoder_attn.out_proj.weight', 'model.decoder.layers.5.encoder_attn.out_proj.bias', 'model.decoder.layers.5.encoder_attn_layer_norm.weight', 'model.decoder.layers.5.encoder_attn_layer_norm.bias', 'model.decoder.layers.5.fc1.weight', 'model.decoder.layers.5.fc1.bias', 'model.decoder.layers.5.fc2.weight', 'model.decoder.layers.5.fc2.bias', 'model.decoder.layers.5.final_layer_norm.weight', 'model.decoder.layers.5.final_layer_norm.bias', 'model.decoder.layernorm_embedding.weight', 'model.decoder.layernorm_embedding.bias']
    - This IS expected if you are initializing BartEXP from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BartEXP from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BartEXP were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.bart.final_logits_bias', 'model.bart.model.shared.weight', 'model.bart.model.encoder.embed_tokens.weight', 'model.bart.model.encoder.embed_positions.weight', 'model.bart.model.encoder.layers.0.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.0.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.0.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.0.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.0.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.0.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.0.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.0.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.0.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.0.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.0.fc1.weight', 'model.bart.model.encoder.layers.0.fc1.bias', 'model.bart.model.encoder.layers.0.fc2.weight', 'model.bart.model.encoder.layers.0.fc2.bias', 'model.bart.model.encoder.layers.0.final_layer_norm.weight', 'model.bart.model.encoder.layers.0.final_layer_norm.bias', 'model.bart.model.encoder.layers.1.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.1.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.1.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.1.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.1.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.1.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.1.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.1.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.1.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.1.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.1.fc1.weight', 'model.bart.model.encoder.layers.1.fc1.bias', 'model.bart.model.encoder.layers.1.fc2.weight', 
'model.bart.model.encoder.layers.1.fc2.bias', 'model.bart.model.encoder.layers.1.final_layer_norm.weight', 'model.bart.model.encoder.layers.1.final_layer_norm.bias', 'model.bart.model.encoder.layers.2.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.2.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.2.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.2.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.2.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.2.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.2.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.2.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.2.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.2.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.2.fc1.weight', 'model.bart.model.encoder.layers.2.fc1.bias', 'model.bart.model.encoder.layers.2.fc2.weight', 'model.bart.model.encoder.layers.2.fc2.bias', 'model.bart.model.encoder.layers.2.final_layer_norm.weight', 'model.bart.model.encoder.layers.2.final_layer_norm.bias', 'model.bart.model.encoder.layers.3.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.3.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.3.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.3.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.3.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.3.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.3.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.3.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.3.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.3.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.3.fc1.weight', 'model.bart.model.encoder.layers.3.fc1.bias', 'model.bart.model.encoder.layers.3.fc2.weight', 'model.bart.model.encoder.layers.3.fc2.bias', 'model.bart.model.encoder.layers.3.final_layer_norm.weight', 
'model.bart.model.encoder.layers.3.final_layer_norm.bias', 'model.bart.model.encoder.layers.4.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.4.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.4.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.4.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.4.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.4.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.4.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.4.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.4.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.4.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.4.fc1.weight', 'model.bart.model.encoder.layers.4.fc1.bias', 'model.bart.model.encoder.layers.4.fc2.weight', 'model.bart.model.encoder.layers.4.fc2.bias', 'model.bart.model.encoder.layers.4.final_layer_norm.weight', 'model.bart.model.encoder.layers.4.final_layer_norm.bias', 'model.bart.model.encoder.layers.5.self_attn.k_proj.weight', 'model.bart.model.encoder.layers.5.self_attn.k_proj.bias', 'model.bart.model.encoder.layers.5.self_attn.v_proj.weight', 'model.bart.model.encoder.layers.5.self_attn.v_proj.bias', 'model.bart.model.encoder.layers.5.self_attn.q_proj.weight', 'model.bart.model.encoder.layers.5.self_attn.q_proj.bias', 'model.bart.model.encoder.layers.5.self_attn.out_proj.weight', 'model.bart.model.encoder.layers.5.self_attn.out_proj.bias', 'model.bart.model.encoder.layers.5.self_attn_layer_norm.weight', 'model.bart.model.encoder.layers.5.self_attn_layer_norm.bias', 'model.bart.model.encoder.layers.5.fc1.weight', 'model.bart.model.encoder.layers.5.fc1.bias', 'model.bart.model.encoder.layers.5.fc2.weight', 'model.bart.model.encoder.layers.5.fc2.bias', 'model.bart.model.encoder.layers.5.final_layer_norm.weight', 'model.bart.model.encoder.layers.5.final_layer_norm.bias', 'model.bart.model.encoder.layernorm_embedding.weight', 
'model.bart.model.encoder.layernorm_embedding.bias', 'model.bart.model.decoder.embed_tokens.weight', 'model.bart.model.decoder.embed_positions.weight', 'model.bart.model.decoder.layers.0.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.0.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.0.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.0.self_attn.v_proj.bias', 'model.bart.model.decoder.layers.0.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.0.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.0.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.0.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.0.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.0.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.0.encoder_attn.k_proj.weight', 'model.bart.model.decoder.layers.0.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.0.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.0.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.0.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.0.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.0.encoder_attn.out_proj.weight', 'model.bart.model.decoder.layers.0.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.0.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.0.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.0.fc1.weight', 'model.bart.model.decoder.layers.0.fc1.bias', 'model.bart.model.decoder.layers.0.fc2.weight', 'model.bart.model.decoder.layers.0.fc2.bias', 'model.bart.model.decoder.layers.0.final_layer_norm.weight', 'model.bart.model.decoder.layers.0.final_layer_norm.bias', 'model.bart.model.decoder.layers.1.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.1.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.1.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.1.self_attn.v_proj.bias', 
'model.bart.model.decoder.layers.1.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.1.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.1.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.1.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.1.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.1.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.1.encoder_attn.k_proj.weight', 'model.bart.model.decoder.layers.1.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.1.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.1.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.1.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.1.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.1.encoder_attn.out_proj.weight', 'model.bart.model.decoder.layers.1.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.1.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.1.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.1.fc1.weight', 'model.bart.model.decoder.layers.1.fc1.bias', 'model.bart.model.decoder.layers.1.fc2.weight', 'model.bart.model.decoder.layers.1.fc2.bias', 'model.bart.model.decoder.layers.1.final_layer_norm.weight', 'model.bart.model.decoder.layers.1.final_layer_norm.bias', 'model.bart.model.decoder.layers.2.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.2.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.2.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.2.self_attn.v_proj.bias', 'model.bart.model.decoder.layers.2.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.2.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.2.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.2.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.2.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.2.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.2.encoder_attn.k_proj.weight', 
'model.bart.model.decoder.layers.2.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.2.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.2.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.2.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.2.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.2.encoder_attn.out_proj.weight', 'model.bart.model.decoder.layers.2.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.2.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.2.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.2.fc1.weight', 'model.bart.model.decoder.layers.2.fc1.bias', 'model.bart.model.decoder.layers.2.fc2.weight', 'model.bart.model.decoder.layers.2.fc2.bias', 'model.bart.model.decoder.layers.2.final_layer_norm.weight', 'model.bart.model.decoder.layers.2.final_layer_norm.bias', 'model.bart.model.decoder.layers.3.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.3.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.3.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.3.self_attn.v_proj.bias', 'model.bart.model.decoder.layers.3.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.3.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.3.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.3.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.3.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.3.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.3.encoder_attn.k_proj.weight', 'model.bart.model.decoder.layers.3.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.3.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.3.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.3.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.3.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.3.encoder_attn.out_proj.weight', 
'model.bart.model.decoder.layers.3.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.3.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.3.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.3.fc1.weight', 'model.bart.model.decoder.layers.3.fc1.bias', 'model.bart.model.decoder.layers.3.fc2.weight', 'model.bart.model.decoder.layers.3.fc2.bias', 'model.bart.model.decoder.layers.3.final_layer_norm.weight', 'model.bart.model.decoder.layers.3.final_layer_norm.bias', 'model.bart.model.decoder.layers.4.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.4.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.4.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.4.self_attn.v_proj.bias', 'model.bart.model.decoder.layers.4.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.4.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.4.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.4.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.4.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.4.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.4.encoder_attn.k_proj.weight', 'model.bart.model.decoder.layers.4.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.4.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.4.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.4.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.4.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.4.encoder_attn.out_proj.weight', 'model.bart.model.decoder.layers.4.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.4.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.4.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.4.fc1.weight', 'model.bart.model.decoder.layers.4.fc1.bias', 'model.bart.model.decoder.layers.4.fc2.weight', 'model.bart.model.decoder.layers.4.fc2.bias', 
'model.bart.model.decoder.layers.4.final_layer_norm.weight', 'model.bart.model.decoder.layers.4.final_layer_norm.bias', 'model.bart.model.decoder.layers.5.self_attn.k_proj.weight', 'model.bart.model.decoder.layers.5.self_attn.k_proj.bias', 'model.bart.model.decoder.layers.5.self_attn.v_proj.weight', 'model.bart.model.decoder.layers.5.self_attn.v_proj.bias', 'model.bart.model.decoder.layers.5.self_attn.q_proj.weight', 'model.bart.model.decoder.layers.5.self_attn.q_proj.bias', 'model.bart.model.decoder.layers.5.self_attn.out_proj.weight', 'model.bart.model.decoder.layers.5.self_attn.out_proj.bias', 'model.bart.model.decoder.layers.5.self_attn_layer_norm.weight', 'model.bart.model.decoder.layers.5.self_attn_layer_norm.bias', 'model.bart.model.decoder.layers.5.encoder_attn.k_proj.weight', 'model.bart.model.decoder.layers.5.encoder_attn.k_proj.bias', 'model.bart.model.decoder.layers.5.encoder_attn.v_proj.weight', 'model.bart.model.decoder.layers.5.encoder_attn.v_proj.bias', 'model.bart.model.decoder.layers.5.encoder_attn.q_proj.weight', 'model.bart.model.decoder.layers.5.encoder_attn.q_proj.bias', 'model.bart.model.decoder.layers.5.encoder_attn.out_proj.weight', 'model.bart.model.decoder.layers.5.encoder_attn.out_proj.bias', 'model.bart.model.decoder.layers.5.encoder_attn_layer_norm.weight', 'model.bart.model.decoder.layers.5.encoder_attn_layer_norm.bias', 'model.bart.model.decoder.layers.5.fc1.weight', 'model.bart.model.decoder.layers.5.fc1.bias', 'model.bart.model.decoder.layers.5.fc2.weight', 'model.bart.model.decoder.layers.5.fc2.bias', 'model.bart.model.decoder.layers.5.final_layer_norm.weight', 'model.bart.model.decoder.layers.5.final_layer_norm.bias', 'model.bart.model.decoder.layernorm_embedding.weight', 'model.bart.model.decoder.layernorm_embedding.bias', 'model.bart.lm_head.weight']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

I am aiming to use BartEXP for some modification over the output (adding another head and using multiple losses), so I need it to be in that format. How do I need to construct the class so that I inherit from BartForConditionalGeneration and use its forward method / override it?

Hi @latent

This won’t work, since it changes the module structure. You can initialize BartForConditionalGeneration inside BartExp by calling from_pretrained, or you can create a class like BartForConditionalGeneration and put your custom layers inside that class. You can see how BartForConditionalGeneration is implemented here.

You could modify that class easily.

@valhalla Thanks, something like:

class BartExp(nn.Module):
    """Plain nn.Module wrapper around a pretrained
    BartForConditionalGeneration, ready to be extended with extra heads
    and losses."""

    def __init__(self, config):
        super().__init__()
        # ``config`` is actually a model identifier or checkpoint path
        # (e.g. 'facebook/bart-base'), not a BartConfig; the parameter name
        # is kept for backward compatibility with existing callers.
        self.bart = BartForConditionalGeneration.from_pretrained(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Delegate to the wrapped BART model and return its output.

        BUG FIX: the original forwarded only 5 of the 15 accepted arguments
        and silently dropped the rest (e.g. ``decoder_attention_mask``,
        ``past_key_values``); all arguments are now passed through, which is
        behavior-compatible since every default is ``None``.
        """
        return self.bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

# Builds the wrapper; from_pretrained inside __init__ downloads/loads the
# pretrained weights, so no key-mismatch warning is produced here.
model = BartExp('facebook/bart-base')

and then when I’m inside a training loop:

output = model(input_ids=original_input_ids,attention_mask=original_attention_mask,
              labels=labels,decoder_input_ids=None,encoder_outputs=None)

It does seem to be working… am i doing it right?

1 Like

It looks right, one thing to remember though is, once you train this model you’ll need load weights using torch.load_state_dict, but if you want to leverage from_pretrained then you could modify the BartForConditionalGeneration class to add custom layers.

I am starting (before fine tuning) with:

class Exp(nn.Module):
    """nn.Module wrapper holding a pretrained BartForConditionalGeneration
    as ``self.bart``, the starting point before fine-tuning."""

    def __init__(self, config):
        super().__init__()
        # ``config`` is a model identifier or checkpoint path
        # (e.g. 'facebook/bart-base'), not a BartConfig object; the name is
        # preserved for backward compatibility.
        self.bart = BartForConditionalGeneration.from_pretrained(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Delegate to the wrapped BART model and return its output.

        BUG FIX: the original forwarded only ``input_ids``,
        ``attention_mask``, ``labels``, ``decoder_input_ids`` and
        ``encoder_outputs`` and silently dropped the other ten arguments;
        all are now passed through (``None`` defaults match the wrapped
        model's own signature, so existing calls behave identically).
        """
        return self.bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

Loading the pre-trained model before fine-tuning :
model = Exp('facebook/bart-base')

Since the model is being trained on multiple GPUs, I wrap it with model = torch.nn.DataParallel(model)
and then my saving method is

    # Unwrap nn.DataParallel (which stores the real model under ``.module``)
    # so the saved state_dict keys carry no 'module.' prefix.
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    # Save only the inner BartForConditionalGeneration so it can later be
    # reloaded with from_pretrained(args.output_dir).
    model_to_save.bart.save_pretrained(args.output_dir)
    # Persist the tokenizer vocabulary next to the model weights.
    tokenizer.save_vocabulary(args.output_dir)

So, when loading the trained model , I am doing the same procedure :
model = Exp(args.output_dir)
where output_dir is the directory containing my pytorch_model.bin file. I am loading the fine-tuned model, not initializing the pre-trained model again, right?