Hi,
I am extending the BART model so that it has two decoders instead of one:
```python
class BartModelV2(BartModel):
    def __init__(self, config: BartConfig):
        super().__init__(config)
        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
        self.encoder = BartEncoder(config, self.shared)
        self.decoder1 = BartDecoder(config, self.shared)
        self.decoder2 = BartDecoder(config, self.shared)
        self.init_weights()
```
Now, when I use the `from_pretrained` function to load the weights from a pretrained BART checkpoint, there is a mismatch between the `decoder` attribute of the original BART model and my new `BartModelV2`. I intended to load the weights of the original BART decoder into both `decoder1` and `decoder2`. Any idea how I could force this behavior?
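For what it's worth, one approach I have been considering (not sure it is the intended way) is to load the stock checkpoint's state dict first and remap every `decoder.*` key to both `decoder1.*` and `decoder2.*` before calling `load_state_dict` on the new model. A minimal sketch of the remapping with plain dicts standing in for real state dicts (the key names are illustrative):

```python
# Toy stand-in for a pretrained state_dict; real keys look like
# "decoder.layers.0.self_attn.k_proj.weight" etc.
pretrained_state = {
    "shared.weight": "emb",
    "encoder.layers.0.weight": "enc0",
    "decoder.layers.0.weight": "dec0",
}

def duplicate_decoder(state):
    """Map every 'decoder.*' key to both 'decoder1.*' and 'decoder2.*',
    leaving all other keys untouched."""
    new_state = {}
    for key, value in state.items():
        if key.startswith("decoder."):
            suffix = key[len("decoder."):]
            new_state["decoder1." + suffix] = value
            new_state["decoder2." + suffix] = value
        else:
            new_state[key] = value
    return new_state

remapped = duplicate_decoder(pretrained_state)
```

With real tensors the idea would be `model.load_state_dict(duplicate_decoder(BartModel.from_pretrained(...).state_dict()), strict=False)`, but I have not verified that this handles every edge case (e.g. tied embeddings), so a cleaner supported way would be welcome.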
Thanks a lot for your help!