I'm trying to build a custom vision encoder-decoder model.
I want to use a pre-trained encoder with a decoder trained from scratch, so I cannot use VisionEncoderDecoderModel.from_pretrained().
Specifically, I want a pre-trained DeiT model as the encoder and a custom-trained Electra model (loaded via AutoModelForCausalLM) as the decoder; a rough sketch of how I build the decoder config is included after the class below.
My code is below. During the train step there is no problem.
But when I try to decode I get an error saying the model has no attribute 'generate' (a sketch of the failing call is also below). How can I implement or import the generate function?
import torch.nn as nn
from transformers import AutoModel, AutoModelForCausalLM, VisionEncoderDecoderConfig
# shift_tokens_right shifts the labels one position to the right to build decoder inputs
from transformers.models.bart.modeling_bart import shift_tokens_right

class CustomEncoderDecoderModel(nn.Module):
    config_class = VisionEncoderDecoderConfig

    def __init__(self, encoder_name, decoder_config, config=None):
        super().__init__()
        # pre-trained vision encoder, randomly initialized decoder
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.decoder_config = decoder_config
        self.decoder = AutoModelForCausalLM.from_config(self.decoder_config)
        self.config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(self.encoder.config, self.decoder.config)

        # ignore padding positions when computing the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.decoder.config.pad_token_id)
        # project encoder states to the decoder's hidden size when they differ
        self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
    def forward(self, pixel_values, labels=None, decoder_input_ids=None,
                decoder_attention_mask=None,
                decoder_inputs_embeds=None,
                past_key_values=None):
        # output_hidden_states=True so encoder_outputs.hidden_states is populated below
        encoder_outputs = self.encoder(pixel_values,
                                       output_attentions=True,
                                       output_hidden_states=True)
        encoder_hidden_states = encoder_outputs[0]
        # image encoders take no attention mask; every patch embedding is attended to
        encoder_attention_mask = None
        if decoder_input_ids is None and decoder_inputs_embeds is None and labels is not None:
            # build decoder inputs by shifting the labels right;
            # decoder_start_token_id=2 must match the decoder tokenizer's BOS/CLS id
            decoder_input_ids = shift_tokens_right(
                labels, self.decoder.config.pad_token_id, decoder_start_token_id=2
            )
        if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=True,
            use_cache=True,
            past_key_values=past_key_values,
        )
        logits = decoder_outputs[0]
        # labels already line up with the right-shifted decoder inputs, so no extra shift here
        loss = None
        if labels is not None:
            loss = self.criterion(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
        return {'loss': loss, 'logits': logits,
                'past_key_values': decoder_outputs.past_key_values,
                'decoder_hidden_states': decoder_outputs.hidden_states,
                'decoder_attentions': decoder_outputs.attentions,
                'cross_attentions': decoder_outputs.cross_attentions,
                'encoder_hidden_states': encoder_outputs.hidden_states,
                'encoder_attentions': encoder_outputs.attentions,
                }
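For reference, this is roughly how I build the decoder config (the vocab size and other values here are placeholders for my actual custom-trained Electra); cross-attention has to be switched on so the decoder can consume encoder_hidden_states:

from transformers import ElectraConfig

# placeholder values; my real config comes from my custom-trained Electra
decoder_config = ElectraConfig(
    vocab_size=30522,
    is_decoder=True,           # run Electra as an autoregressive decoder
    add_cross_attention=True,  # required so encoder_hidden_states are attended to
)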
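And this is roughly how I use the model (the checkpoint name and the pixel_values/labels tensors are placeholders for my own data):

model = CustomEncoderDecoderModel('facebook/deit-base-distilled-patch16-224', decoder_config)

# train step: forward() works fine
out = model(pixel_values=pixel_values, labels=labels)
out['loss'].backward()

# inference: fails, because a plain nn.Module subclass does not inherit
# generate() from transformers' GenerationMixin
generated_ids = model.generate(pixel_values)
# AttributeError: 'CustomEncoderDecoderModel' object has no attribute 'generate'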