Replacing the decoder of an xxxEncoderDecoderModel

Someone asked how to replace the decoder of an existing VisionEncoderDecoderModel from the hub. Specifically, the TrOCR checkpoints currently on the hub all use an English-only language model (RoBERTa) as decoder, so how can it be replaced with a multilingual XLM-RoBERTa model?

Here's the answer:

```python
from transformers import VisionEncoderDecoderModel, RobertaForCausalLM

# load the full TrOCR model (image encoder + text decoder) from the hub
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# replace decoder
model.decoder = RobertaForCausalLM.from_pretrained("xlm-roberta-base", is_decoder=True, add_cross_attention=True)
```

As you can see, we first initialize the entire model from the hub and then replace its decoder with a custom one. Note that we initialize a RobertaForCausalLM model, which includes the language modeling head on top (as opposed to RobertaModel). We also set the is_decoder and add_cross_attention attributes of the decoder's configuration so that cross-attention layers are added between the encoder and the decoder. When the decoder is initialized, a warning is printed indicating that the weights of these cross-attention layers are randomly initialized: they are not part of the xlm-roberta-base checkpoint, so the model has to be fine-tuned before it produces useful output.
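To make the cross-attention mentioned above a bit more concrete, here is a minimal single-head sketch in NumPy. It is a simplification for illustration, not the Hugging Face implementation (which uses multiple heads, biases and an output layer). Queries come from the decoder's hidden states (the text tokens), while keys and values come from the encoder's hidden states (the image patches). The matrices w_q, w_k and w_v play the role of the freshly added, randomly initialized cross-attention weights. The sketch assumes encoder and decoder share the same hidden size d_model; all names and sizes are illustrative.

```python
import numpy as np


def softmax(x, axis=-1):
    # subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(dec_states, enc_states, w_q, w_k, w_v):
    """Single-head cross-attention: each decoder position attends over all encoder positions.

    dec_states: (T_dec, d_model) decoder hidden states, used for the queries
    enc_states: (T_enc, d_model) encoder hidden states, used for the keys and values
    """
    q = dec_states @ w_q                      # (T_dec, d_k)
    k = enc_states @ w_k                      # (T_enc, d_k)
    v = enc_states @ w_v                      # (T_enc, d_k)
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (T_dec, T_enc), scaled dot products
    weights = softmax(scores, axis=-1)        # each decoder row sums to 1
    return weights @ v, weights               # (T_dec, d_k), (T_dec, T_enc)


rng = np.random.default_rng(0)
d_model, d_k = 8, 8
T_enc, T_dec = 5, 3  # illustrative: 5 image patches, 3 text tokens

# Illustrative stand-ins for the new, randomly initialized cross-attention weights
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

enc = rng.normal(size=(T_enc, d_model))
dec = rng.normal(size=(T_dec, d_model))
out, attn = cross_attention(dec, enc, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (3, 8) (3, 5)
```

Each row of attn is a probability distribution over encoder positions, i.e. how strongly one text token attends to each image patch. Because the projection weights start out random, these distributions carry no meaning until the model is fine-tuned.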

Preparing the data

Also note that if you are going to prepare data for the model, you must use the tokenizer that matches the new decoder to create the labels (in this case, XLMRobertaTokenizer). Let's define a TrOCRProcessor (which wraps a feature extractor and a tokenizer into a single object) by first loading the one from the corresponding checkpoint and then replacing its tokenizer:

```python
from transformers import TrOCRProcessor, XLMRobertaTokenizer

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
processor.tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
```

One can then prepare a single (image, text) pair for the model as follows:

```python
from PIL import Image
import torch

image = Image.open("...").convert("RGB")
text = "..."

pixel_values = processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
# (without an explicit max_length, padding="max_length" pads to the tokenizer's model_max_length)
labels = processor.tokenizer(text, padding="max_length", truncation=True).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != processor.tokenizer.pad_token_id else -100 for label in labels]

encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
```
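The value -100 is not arbitrary: VisionEncoderDecoderModel computes its loss with PyTorch's cross-entropy, which by default skips targets equal to its ignore_index of -100. The pure-Python sketch below mimics that behaviour on a toy example. The 4-token vocabulary, the pad id of 1 and the probabilities are made up for illustration; in practice the pad id comes from processor.tokenizer.pad_token_id.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
PAD_ID = 1           # toy value; use processor.tokenizer.pad_token_id in practice


def mask_padding(label_ids, pad_id):
    """Replace PAD token ids with IGNORE_INDEX (same logic as the list comprehension above)."""
    return [IGNORE_INDEX if t == pad_id else t for t in label_ids]


def mean_nll(log_probs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    log_probs: per position, a list of log-probabilities over the (toy) vocabulary
    labels: per position, the target token id, or IGNORE_INDEX to skip that position
    """
    total, count = 0.0, 0
    for row, target in zip(log_probs, labels):
        if target == IGNORE_INDEX:
            continue  # padded position: adds nothing to the sum or to the count
        total -= row[target]
        count += 1
    return total / count


# Toy example: 3 real tokens followed by 2 PAD tokens, vocabulary of size 4
labels = mask_padding([0, 3, 2, PAD_ID, PAD_ID], PAD_ID)
probs = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.7, 0.1, 0.1],      # padded position: ignored after masking
    [0.25, 0.25, 0.25, 0.25],  # padded position: ignored after masking
]
log_probs = [[math.log(p) for p in row] for row in probs]

print(labels)                                 # [0, 3, 2, -100, -100]
print(round(mean_nll(log_probs, labels), 4))  # 0.3567, i.e. -log(0.7)

# Without masking, the padded positions would be averaged into the loss too
unmasked = mean_nll(log_probs, [0, 3, 2, PAD_ID, PAD_ID])
```

Only the three real tokens contribute to the masked loss, so the model is not rewarded or penalized for what it predicts at padded positions.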

I also tried this with the trocr-small-handwritten checkpoint, and I get the following error when I try to run a prediction on an image:
'VisionEncoderDecoderModel' object has no attribute 'enc_to_dec_proj'

```python
from transformers import RobertaForCausalLM, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
model.decoder = RobertaForCausalLM.from_pretrained('xlm-roberta-base', is_decoder=True, add_cross_attention=True)
```