Using EncoderDecoderModel

Hi,
I have tried to combine ViT (with the BEiT base patch16-384 weights) as the encoder with a BERT model as the decoder, like Microsoft's new TrOCR paper on arXiv.


If I use EncoderDecoderModel, it does not accept pixel_values for the encoder:

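import torch
from transformers import BertModel, EncoderDecoderModel, ViTFeatureExtractor, ViTModel, XLMRobertaTokenizer

feat_extractor = ViTFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-384")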
tokenizer = XLMRobertaTokenizer.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384")

encoder = ViTModel.from_pretrained("microsoft/beit-base-patch16-384", output_attentions=True, output_hidden_states=True, return_dict=True, is_decoder=False)

decoder = BertModel.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384", is_decoder=True, add_cross_attention=True, return_dict=True)

encoder_inputs = feat_extractor(torch.randn(3, 512, 512), return_tensors='pt')
decoder_input = tokenizer.encode_plus("Hello World", return_tensors='pt')
print(decoder_input)

#encoder_outputs = encoder(**encoder_inputs)
#decoder_outputs = decoder(input_ids=decoder_input['input_ids'],encoder_hidden_states=encoder_outputs.last_hidden_state)

model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
# raises the TypeError below: the encoder is still called with input_ids
outputs = model(inputs_embeds=encoder_inputs['pixel_values'], decoder_input_ids=decoder_input['input_ids'])

  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 425, in forward
    encoder_outputs = self.encoder(
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'input_ids'
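Judging by the traceback, EncoderDecoderModel.forward hands the text-style keywords (input_ids, attention_mask, inputs_embeds) to whatever encoder it wraps, even when they are None, while ViTModel.forward takes pixel_values and has no input_ids parameter. The stdlib-only sketch below reproduces that failure with hypothetical stand-in functions (vit_like_forward, bert_like_forward, call_like_text_wrapper and call_by_signature are illustrative names, not transformers APIs) and shows one way a vision-aware wrapper could route its main input by inspecting the encoder's signature.

```python
import inspect


# Hypothetical stand-ins that only mimic the parameter *names* of the real
# forward() methods; they are not transformers code.
def vit_like_forward(pixel_values=None, output_attentions=None, return_dict=None):
    return {"main_input": "pixel_values"}


def bert_like_forward(input_ids=None, attention_mask=None, inputs_embeds=None, return_dict=None):
    return {"main_input": "input_ids"}


def call_like_text_wrapper(encoder_forward, input_ids=None, attention_mask=None, inputs_embeds=None):
    # Mimics a text-only wrapper: the text keywords are always passed,
    # even when their value is None.
    return encoder_forward(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)


def call_by_signature(encoder_forward, inputs):
    # Pass the main input under whichever keyword the encoder declares.
    params = inspect.signature(encoder_forward).parameters
    for name in ("pixel_values", "input_values", "input_ids"):  # image, speech (e.g. Wav2Vec2), text
        if name in params:
            return encoder_forward(**{name: inputs})
    raise TypeError("encoder accepts none of the known input keywords")


try:
    call_like_text_wrapper(vit_like_forward, inputs_embeds="fake image batch")
except TypeError as err:
    print("text-style wrapper:", err)

print(call_by_signature(vit_like_forward, "fake image batch"))
print(call_by_signature(bert_like_forward, [[0, 1, 2]]))
```

Dispatching on the signature keeps one wrapper usable for image, speech and text encoders; a dedicated vision wrapper could instead simply hard-code pixel_values.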

And if I instead call the encoder and decoder separately:

import torch
from transformers import BertModel, ViTFeatureExtractor, ViTModel, XLMRobertaTokenizer

feat_extractor = ViTFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-384")
tokenizer = XLMRobertaTokenizer.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384")

encoder = ViTModel.from_pretrained("microsoft/beit-base-patch16-384", output_attentions=True, output_hidden_states=True, return_dict=True, is_decoder=False)
decoder = BertModel.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384", is_decoder=True, add_cross_attention=True, return_dict=True)

encoder_inputs = feat_extractor(torch.randn(3, 512, 512), return_tensors='pt')
decoder_input = tokenizer.encode_plus("Hello World", return_tensors='pt')
print(decoder_input)

encoder_outputs = encoder(**encoder_inputs)
decoder_outputs = decoder(input_ids=decoder_input['input_ids'], encoder_hidden_states=encoder_outputs.last_hidden_state)

#model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
#outputs = model(inputs_embeds=encoder_inputs['pixel_values'], decoder_input_ids=decoder_input['input_ids'])

  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 990, in forward
    encoder_outputs = self.encoder(
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 582, in forward
    layer_outputs = layer_module(
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 494, in forward
    cross_attention_outputs = self.crossattention(
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 401, in forward
    self_outputs = self.self(
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 280, in forward
    key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (577x768 and 384x384)
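The shapes in this RuntimeError can be worked out by hand. A rough sketch, assuming the usual base ViT/BEiT layout (one [CLS] token plus one token per 16x16 patch, hidden size 768), that the feature extractor resizes to 384x384, and the 384-dimensional hidden size given by the "H384" in the MiniLM checkpoint name:

```python
# Rough shape bookkeeping for the RuntimeError above; assumes the usual
# ViT/BEiT-base layout and that the feature extractor resizes to 384x384.
image_size = 384       # input resolution of the *-384 checkpoint
patch_size = 16        # the "patch16" in the checkpoint name
encoder_hidden = 768   # hidden size of a base-sized ViT/BEiT
decoder_hidden = 384   # the "H384" in Multilingual-MiniLM-L12-H384

num_patches = (image_size // patch_size) ** 2   # 24 * 24 = 576
seq_len = num_patches + 1                       # plus the [CLS] token = 577

# The encoder output per image is seq_len x encoder_hidden (577x768), while the
# decoder's cross-attention key projection is a 384 -> 384 linear layer.
print(f"encoder_hidden_states per image: {seq_len}x{encoder_hidden}")
print(f"cross-attention key weight: {decoder_hidden}x{decoder_hidden}")
print("last dimensions match:", encoder_hidden == decoder_hidden)
```

So calling the models separately only works if something maps the 768-dimensional encoder states to 384 dimensions (or the decoder has a hidden size of 768) before they reach the cross-attention.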

Is there currently a way to do this with the transformers library?
Otherwise, I think another option is to use the pretrained ViT model from transformers and build the decoder with nn.TransformerDecoder (or from scratch) and initialize it with pretrained weights.
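A minimal PyTorch sketch of that fallback, assuming only torch is installed: a random tensor of shape (1, 577, 768) stands in for encoder_outputs.last_hidden_state so nothing is downloaded, and the class name ImageToTextDecoder, the projection enc_to_dec, the vocabulary size and the layer counts are all hypothetical choices. The linear projection is what resolves the 768 vs 384 mismatch from the second traceback; nn.TransformerDecoder then cross-attends to the projected states under a causal mask.

```python
import torch
import torch.nn as nn


class ImageToTextDecoder(nn.Module):
    """Hypothetical module: project image-encoder states, then decode text with cross-attention."""

    def __init__(self, enc_dim, dec_dim, vocab_size, num_layers=2, nhead=6):
        super().__init__()
        # Maps the encoder width (e.g. 768) to the decoder width (e.g. 384).
        self.enc_to_dec = nn.Linear(enc_dim, dec_dim)
        # Token embeddings only; a real decoder would also add position embeddings.
        self.embed = nn.Embedding(vocab_size, dec_dim)
        layer = nn.TransformerDecoderLayer(d_model=dec_dim, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.lm_head = nn.Linear(dec_dim, vocab_size)

    def forward(self, encoder_hidden_states, decoder_input_ids):
        memory = self.enc_to_dec(encoder_hidden_states)   # (batch, src_len, dec_dim)
        tgt = self.embed(decoder_input_ids)                # (batch, tgt_len, dec_dim)
        tgt_len = decoder_input_ids.size(1)
        # Causal mask: -inf above the diagonal, so position i only attends to positions <= i.
        causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
        hidden = self.decoder(tgt, memory, tgt_mask=causal_mask)
        return self.lm_head(hidden)                        # (batch, tgt_len, vocab_size)


torch.manual_seed(0)
batch, src_len, enc_dim = 1, 577, 768     # same shape as the ViT/BEiT output in the traceback
tgt_len, dec_dim, vocab_size = 6, 384, 1000

model = ImageToTextDecoder(enc_dim, dec_dim, vocab_size)
encoder_hidden_states = torch.randn(batch, src_len, enc_dim)   # stand-in for last_hidden_state
decoder_input_ids = torch.randint(0, vocab_size, (batch, tgt_len))
logits = model(encoder_hidden_states, decoder_input_ids)
print(logits.shape)
```

In a real setup, encoder_hidden_states would come from encoder(**encoder_inputs).last_hidden_state. Note that nn.TransformerDecoder packs its attention weights differently from BertModel, so initializing it from a pretrained BERT checkpoint would need a manual weight mapping.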

Thanks 🙂

Hi,

EncoderDecoderModel is meant to combine any bidirectional text encoder (e.g. BERT) with any autoregressive text decoder (e.g. GPT2). We’re planning to add a VisionEncoderDecoderModel (recently we’ve added SpeechEncoderDecoderModel, which allows you to combine any speech autoencoding model such as Wav2Vec2 with any autoregressive text decoder).

Feel free to contribute this if you are interested!

Thanks 🙂
Yes, I think I will look into this during my holidays, and also at the contributing guide.