Hi,

I have tried to combine a ViT (BEiT weights, patch16-384) as encoder with a BERT model as decoder (like Microsoft's new TrOCR paper on arXiv). If I use `EncoderDecoderModel`, it does not support `pixel_values` for the encoder:
```python
import torch
from transformers import (ViTFeatureExtractor, ViTModel, XLMRobertaTokenizer,
                          BertModel, EncoderDecoderModel)

feat_extractor = ViTFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-384")
tokenizer = XLMRobertaTokenizer.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384")
encoder = ViTModel.from_pretrained("microsoft/beit-base-patch16-384", output_attentions=True,
                                   output_hidden_states=True, return_dict=True, is_decoder=False)
decoder = BertModel.from_pretrained("microsoft/Multilingual-MiniLM-L12-H384", is_decoder=True,
                                    add_cross_attention=True, return_dict=True)

encoder_inputs = feat_extractor(torch.randn(3, 512, 512), return_tensors='pt')  # dummy image
decoder_input = tokenizer.encode_plus("Hello World", return_tensors='pt')
print(decoder_input)

model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
# Fails: EncoderDecoderModel.forward feeds input_ids to the encoder,
# but ViTModel.forward only accepts pixel_values.
outputs = model(inputs_embeds=encoder_inputs['pixel_values'],
                decoder_input_ids=decoder_input['input_ids'])
```
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 425, in forward
encoder_outputs = self.encoder(
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'input_ids'
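For what it's worth, the encoder itself runs fine when called with `pixel_values` directly, so the problem is only how `EncoderDecoderModel.forward` dispatches its inputs. A quick check (same objects as above; the shape in the comment is my expectation for patch16 at 384x384):

```python
# ViTModel accepts pixel_values directly; EncoderDecoderModel instead
# passes input_ids, which ViTModel.forward does not know.
encoder_outputs = encoder(pixel_values=encoder_inputs['pixel_values'])
print(encoder_outputs.last_hidden_state.shape)  # expected: (1, 577, 768), 24*24 patches + [CLS]
```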
And if I instead wire the two models together manually, calling the encoder and decoder directly:
```python
# same models, feature extractor, and inputs as in the first snippet
encoder_outputs = encoder(**encoder_inputs)
# Fails: the decoder's cross-attention projections are 384x384, but the
# BEiT encoder states are 768-dimensional.
decoder_outputs = decoder(input_ids=decoder_input['input_ids'],
                          encoder_hidden_states=encoder_outputs.last_hidden_state)
```
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 990, in forward
encoder_outputs = self.encoder(
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 582, in forward
layer_outputs = layer_module(
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 494, in forward
cross_attention_outputs = self.crossattention(
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 401, in forward
self_outputs = self.self(
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/anaconda3/envs/work/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 280, in forward
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
return F.linear(input, self.weight, self.bias)
File "/home/felix/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (577x768 and 384x384)
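So the manual wiring dies in the decoder's cross-attention: BEiT-base emits 768-dimensional states (the 577x768 above), while MiniLM's cross-attention key/value projections expect 384 (the 384x384). A minimal sketch of the kind of bridge I have in mind, assuming a fresh, untrained projection layer is acceptable:

```python
import torch.nn as nn

# Untrained projection from the encoder's hidden size (768) down to the
# decoder's hidden size (384), applied position-wise.
proj = nn.Linear(encoder.config.hidden_size, decoder.config.hidden_size)

encoder_outputs = encoder(**encoder_inputs)
projected = proj(encoder_outputs.last_hidden_state)          # (1, 577, 384)
decoder_outputs = decoder(input_ids=decoder_input['input_ids'],
                          encoder_hidden_states=projected)   # no shape error now
```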
Is there currently a way to do this with the transformers lib?

Otherwise, I think I could also take the pretrained ViT model from transformers and build the decoder via `nn.TransformerDecoder`, or from scratch and initialize it with pretrained weights.
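Roughly what I mean by that, as an untested sketch (width, head count, layer count, and the missing positional embeddings are all placeholder assumptions; `batch_first` needs PyTorch >= 1.9):

```python
import torch
import torch.nn as nn

d_model = 384                                   # assumed decoder width
vocab_size = tokenizer.vocab_size

layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=6, batch_first=True)
custom_decoder = nn.TransformerDecoder(layer, num_layers=6)
embed = nn.Embedding(vocab_size, d_model)       # positional embeddings omitted for brevity
enc_proj = nn.Linear(encoder.config.hidden_size, d_model)   # 768 -> 384
lm_head = nn.Linear(d_model, vocab_size)

memory = enc_proj(encoder(**encoder_inputs).last_hidden_state)
tgt = embed(decoder_input['input_ids'])
# causal mask so each target position only attends to earlier tokens
tgt_mask = torch.triu(torch.full((tgt.size(1), tgt.size(1)), float('-inf')), diagonal=1)
logits = lm_head(custom_decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask))
```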
Thanks