I'm trying to make a custom vision encoder-decoder model.
I want to use a pre-trained encoder but train the decoder from scratch, so I cannot use VisionEncoderDecoderModel.from_pretrained()
.
Specifically, I want to use a pre-trained deit
model as the encoder, and a custom-trained Electra
as the decoder (using AutoModelForCausalLM
)
I wrote the code below. In the training step, there is no problem.
But I get an error which says "model has no attribute 'generate'". How can I implement or import generate
function?
class CustomEncoderDecoderModel(nn.Module):
    """Vision encoder-decoder with a pre-trained encoder and a from-scratch decoder.

    The encoder is loaded with ``AutoModel.from_pretrained`` and the decoder is
    built (randomly initialised) with ``AutoModelForCausalLM.from_config``.
    Encoder hidden states are handed to the decoder as ``encoder_hidden_states``
    and, when the two hidden sizes differ, first mapped through
    ``enc_to_dec_proj``.

    This class derives from plain ``nn.Module``, so it does not inherit a
    ``generate`` method; a minimal greedy ``generate`` is defined below.

    NOTE(review): if beam search / sampling / Trainer integration is needed,
    wrapping the same two sub-models in
    ``VisionEncoderDecoderModel(encoder=..., decoder=...)`` is presumably the
    supported route -- confirm against the Transformers documentation.
    """

    config_class = VisionEncoderDecoderConfig

    def __init__(self, encoder_name, decoder_config, config=None):
        """Build the encoder, the decoder and the joint configuration.

        Args:
            encoder_name: Name or path passed to ``AutoModel.from_pretrained``.
            decoder_config: Configuration passed to
                ``AutoModelForCausalLM.from_config``. Its ``is_decoder`` and
                ``add_cross_attention`` flags are set to ``True`` in place.
            config: Unused; kept so existing callers keep working.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)

        # The flags must be set BEFORE the decoder is instantiated: the joint
        # config built below also sets them, but by then the decoder layers
        # already exist.
        # NOTE(review): without them the decoder is presumably built without
        # cross-attention and would not attend to the image -- confirm for the
        # decoder architecture in use.
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True
        self.decoder_config = decoder_config
        self.decoder = AutoModelForCausalLM.from_config(self.decoder_config)

        self.config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
            self.encoder.config, self.decoder.config
        )
        self.criterion = nn.CrossEntropyLoss()
        # Always created (even when the sizes match) so the set of parameters,
        # and therefore saved checkpoints, stays the same as before.
        self.enc_to_dec_proj = nn.Linear(
            self.encoder.config.hidden_size, self.decoder.config.hidden_size
        )

    def _encode(self, pixel_values, output_attentions=False):
        """Run the encoder and return ``(encoder_outputs, hidden_states)``.

        ``hidden_states`` is the first encoder output, projected to the decoder
        hidden size when the two sizes differ.
        """
        encoder_outputs = self.encoder(
            pixel_values, output_attentions=output_attentions
        )
        encoder_hidden_states = encoder_outputs[0]
        if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
        return encoder_outputs, encoder_hidden_states

    def forward(
        self,
        pixel_values,
        labels=None,
        decoder_input_ids=None,
        decoder_input_embeds=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        past_key_values=None,
    ):
        """Encode the image, decode, and compute the loss when labels are given.

        Args:
            pixel_values: Input passed to the encoder.
            labels: Target token ids. Optional; when ``None`` no loss is
                computed and explicit decoder inputs are required.
            decoder_input_ids: Decoder input token ids. When neither these nor
                decoder embeddings are given they are derived from ``labels``
                via ``shift_tokens_right``.
            decoder_input_embeds: Misspelled alias of ``decoder_inputs_embeds``
                kept for backward compatibility; used only when
                ``decoder_inputs_embeds`` is ``None``.
            decoder_attention_mask: Attention mask forwarded to the decoder.
            decoder_inputs_embeds: Decoder input embeddings, forwarded as
                ``inputs_embeds``.
            past_key_values: Cached decoder state forwarded to the decoder.

        Returns:
            dict with keys ``loss`` (``None`` without labels), ``logits``,
            ``past_key_values``, ``decoder_hidden_states``,
            ``decoder_attentions``, ``cross_attentions``,
            ``encoder_hidden_state`` and ``encoder_attentions``.

        Raises:
            ValueError: If no labels and no decoder inputs are provided.
        """
        # Fold the misspelled alias into the real parameter so the guard below
        # and the decoder call look at the same value.
        if decoder_inputs_embeds is None:
            decoder_inputs_embeds = decoder_input_embeds

        encoder_outputs, encoder_hidden_states = self._encode(
            pixel_values, output_attentions=True
        )
        # No mask is built for the encoder output; None is forwarded as-is.
        encoder_attention_mask = None

        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if labels is None:
                raise ValueError(
                    "Provide `labels`, `decoder_input_ids` or "
                    "`decoder_inputs_embeds`."
                )
            decoder_input_ids = shift_tokens_right(
                labels, self.decoder.config.pad_token_id, decoder_start_token_id=2
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=True,
            use_cache=True,
            past_key_values=past_key_values,
        )
        logits = decoder_outputs[0]

        loss = None
        if labels is not None:
            loss = self.criterion(
                logits.reshape(-1, self.decoder.config.vocab_size),
                labels.reshape(-1),
            )

        return {
            'loss': loss,
            'logits': logits,
            'past_key_values': decoder_outputs.past_key_values,
            'decoder_hidden_states': decoder_outputs.hidden_states,
            'decoder_attentions': decoder_outputs.attentions,
            'cross_attentions': decoder_outputs.cross_attentions,
            'encoder_hidden_state': encoder_outputs.hidden_states,
            # The original dict listed this key twice; only this (last) value
            # ever reached callers, so the dead first entry was dropped.
            'encoder_attentions': encoder_outputs.attentions,
        }

    def generate(
        self,
        pixel_values,
        max_length=32,
        decoder_start_token_id=2,
        eos_token_id=None,
    ):
        """Greedy-decode token ids for ``pixel_values``.

        Minimal decoding loop (argmax at every step, no beam search, no
        sampling, no key/value cache); it does not switch the module to eval
        mode. The encoder runs once and the decoder is re-run on the growing
        sequence.

        Args:
            pixel_values: Input passed to the encoder.
            max_length: Maximum length of the returned sequences, including
                the start token.
            decoder_start_token_id: First decoder token; the default matches
                the value ``forward`` passes to ``shift_tokens_right``.
            eos_token_id: Token that ends a sequence. Defaults to
                ``self.decoder.config.eos_token_id``; when that is also
                ``None`` decoding always runs to ``max_length``.

        Returns:
            Integer tensor of generated ids, one row per encoder output row,
            each starting with ``decoder_start_token_id``. Rows that finished
            early are padded with the decoder's ``pad_token_id`` (or with
            ``eos_token_id`` when no pad token is configured).

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        import torch  # local import: the module-level imports are not visible here

        if max_length < 1:
            raise ValueError("`max_length` must be at least 1.")
        if eos_token_id is None:
            eos_token_id = self.decoder.config.eos_token_id
        pad_token_id = self.decoder.config.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        with torch.no_grad():
            _, encoder_hidden_states = self._encode(pixel_values)
            batch_size = encoder_hidden_states.shape[0]
            device = encoder_hidden_states.device

            tokens = torch.full(
                (batch_size, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=device,
            )
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            for _ in range(max_length - 1):
                logits = self.decoder(
                    input_ids=tokens,
                    encoder_hidden_states=encoder_hidden_states,
                    use_cache=False,
                )[0]
                next_tokens = logits[:, -1, :].argmax(dim=-1)
                if eos_token_id is not None:
                    # Rows that already emitted EOS only receive padding.
                    next_tokens = next_tokens.masked_fill(finished, pad_token_id)
                    finished = finished | (next_tokens == eos_token_id)
                tokens = torch.cat([tokens, next_tokens.unsqueeze(-1)], dim=-1)
                if eos_token_id is not None and bool(finished.all()):
                    break

        return tokens