Replacing the decoder of an xxxEncoderDecoderModel

A question someone had was how to replace the decoder of an existing VisionEncoderDecoderModel from the hub. Namely, the TrOCR checkpoints currently available on the hub all use an English-only language model (RoBERTa) as decoder - how can one replace it with a multilingual XLM-RoBERTa model?

Here's the answer:

from transformers import VisionEncoderDecoderModel, RobertaForCausalLM

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# replace decoder
model.decoder = RobertaForCausalLM.from_pretrained("xlm-roberta-base", is_decoder=True, add_cross_attention=True)

As you can see, we first initialize the entire model from the hub, after which we replace the decoder with a custom one. Also note that we are initializing a RobertaForCausalLM model, which includes the language modeling head on top (as opposed to RobertaModel). We also set the is_decoder and add_cross_attention attributes of the model's configuration to make sure cross-attention layers are added between the encoder and decoder. A warning will be printed when we initialize the model, indicating that the weights of the cross-attention layers are randomly initialized (so they need to be trained on a downstream dataset).

Preparing the data

Also note that, when preparing data for the model, one must use the appropriate tokenizer to create the labels (in this case, XLMRobertaTokenizer). Let's define a TrOCRProcessor (which wraps a feature extractor and a tokenizer into a single object) by first loading the one from the corresponding checkpoint, and then replacing the tokenizer part:

from transformers import TrOCRProcessor, XLMRobertaTokenizer

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
processor.tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

One can then prepare a single (image, text) pair for the model as follows:

from PIL import Image
import torch

image ="...").convert("RGB")
text = "..."

pixel_values = processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = processor.tokenizer(text, padding="max_length", truncation=True).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != processor.tokenizer.pad_token_id else -100 for label in labels]

encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
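For clarity, the pad-to--100 replacement above can be factored into a small helper. This is a plain-Python sketch; the pad token id of 1 used in the example is XLM-RoBERTa's default:

```python
def mask_pad_labels(input_ids, pad_token_id, ignore_index=-100):
    # replace pad token ids with ignore_index so the cross-entropy loss skips them
    return [ignore_index if token_id == pad_token_id else token_id
            for token_id in input_ids]

# XLM-RoBERTa uses pad_token_id = 1 by default
print(mask_pad_labels([0, 52, 731, 2, 1, 1], pad_token_id=1))
# → [0, 52, 731, 2, -100, -100]
```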

I also tried to use the trocr-small-handwritten encoder and got this error when trying to run prediction with the model:
'VisionEncoderDecoderModel' object has no attribute 'enc_to_dec_proj'

from transformers import RobertaForCausalLM, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
model.decoder = RobertaForCausalLM.from_pretrained('xlm-roberta-base', is_decoder=True, add_cross_attention=True)
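A likely explanation for this error: in transformers, the enc_to_dec_proj layer is only created in the model's __init__, based on the encoder and decoder configs at load time; when the decoder is swapped afterwards, the forward pass re-checks the hidden sizes, decides a projection is needed, and fails to find one. A simplified toy sketch of that pattern (plain Python, not the actual transformers implementation):

```python
class ToyEncoderDecoder:
    """Toy stand-in (NOT the real transformers code) illustrating how a
    projection module that only exists if it was needed at construction
    time can go missing after a post-hoc decoder swap."""

    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        if encoder_hidden_size != decoder_hidden_size:
            # created once, at construction time (stand-in for an nn.Linear)
            self.enc_to_dec_proj = "nn.Linear(encoder_hidden_size, decoder_hidden_size)"

    def forward(self):
        # the forward pass re-checks the sizes at call time
        if self.encoder_hidden_size != self.decoder_hidden_size:
            return self.enc_to_dec_proj  # AttributeError if the decoder was swapped later
        return "no projection needed"


model = ToyEncoderDecoder(384, 384)  # no projection created at "load" time
model.decoder_hidden_size = 768      # swapping in a larger decoder afterwards
try:
    model.forward()
except AttributeError as err:
    print(err)
```

If this is indeed the cause, one workaround is to build the model with VisionEncoderDecoderModel.from_encoder_decoder_pretrained, passing the encoder and decoder checkpoints separately, so the projection is created during construction; alternatively, one could attach the missing nn.Linear projection to the model by hand.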