Replacing the decoder of an xxxEncoderDecoderModel

A question that came up was how to replace the decoder of an existing VisionEncoderDecoderModel from the hub. Namely, the TrOCR model currently only has checkpoints on the hub with an English-only language model (RoBERTa) as decoder - how can it be replaced with a multilingual XLM-RoBERTa model?

Here’s the answer:

from transformers import VisionEncoderDecoderModel, RobertaForCausalLM

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# replace decoder
model.decoder = RobertaForCausalLM.from_pretrained("xlm-roberta-base", is_decoder=True, add_cross_attention=True)

As you can see, we first initialize the entire model from the hub, after which we replace the decoder with a custom one. Also note that we are initializing a RobertaForCausalLM model, which includes the language modeling head on top (as opposed to RobertaModel). We also set the is_decoder and add_cross_attention attributes of the model’s configuration to make sure cross-attention is added between the encoder and decoder. A warning will be printed when we initialize the model, indicating that the weights of the cross-attention layers are randomly initialized.
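One additional thing to keep in mind: assigning a new model to model.decoder does not update the decoder config nested inside the top-level VisionEncoderDecoderConfig, so model.config.decoder still describes the original decoder. A minimal sketch of keeping the two in sync, assuming downstream code (e.g. saving or generation) reads these attributes:

# keep the top-level config in sync with the newly assigned decoder
# (assumption: later code reads model.config.decoder / model.config.vocab_size)
model.config.decoder = model.decoder.config
model.config.vocab_size = model.decoder.config.vocab_size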

Preparing the data

Also note that if you are going to prepare data for the model, you must use the appropriate tokenizer to create the labels (in this case, XLMRobertaTokenizer). Let’s define a TrOCRProcessor (which wraps a feature extractor and a tokenizer into a single object) by first loading the one from the corresponding checkpoint and then replacing the tokenizer part:

from transformers import TrOCRProcessor, XLMRobertaTokenizer

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
processor.tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

One can then prepare a single (image, text) pair for the model as follows:

from PIL import Image
import torch

image = Image.open("...").convert("RGB")
text = "..."

pixel_values = processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = processor.tokenizer(text, padding="max_length", truncation=True).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != processor.tokenizer.pad_token_id else -100 for label in labels]

encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

I also tried to use the trocr-small-handwritten encoder and got this error when trying to run prediction on an image with the model:
‘VisionEncoderDecoderModel’ object has no attribute ‘enc_to_dec_proj’

from transformers import RobertaForCausalLM, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
model.decoder = RobertaForCausalLM.from_pretrained('xlm-roberta-base', is_decoder=True, add_cross_attention=True)

Hi Nielsr, I’ve been training with this approach over the last few days. In my case (I have 50k training pairs), replacing the decoder while keeping the weights from trocr-base-handwritten has resulted in better performance during training, especially in WER, than using the following, which seems to need a lot of training pairs:

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder, decoder)

But I was wondering how exactly we should use a model created in the way you propose during inference. Even when I save the processor and the model, the predictions at inference time seem to be nonsensical (I suppose it’s a vocabulary alignment problem). For instance, when I switch to a multilingual BERT model, I achieve CER = 0.09 and WER = 0.17 on the development set, which are overall good results, using this code:

from transformers import TrOCRProcessor, AutoTokenizer, VisionEncoderDecoderModel, AutoModelForCausalLM

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
processor.tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
processor.save_pretrained('./processor')

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.decoder = AutoModelForCausalLM.from_pretrained("bert-base-multilingual-cased", is_decoder=True, add_cross_attention=True)

model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

model.save_pretrained("trainer_trocr/model")
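
For reference, generation at inference time also depends on the generation-related fields of the same config; a minimal sketch along the lines of the TrOCR fine-tuning examples (the concrete values below are assumptions, not requirements):

# generation settings (assumed values; adjust to your task)
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4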

But when using the trained model for inference with this code:

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = torch.device("cpu")
processor = TrOCRProcessor.from_pretrained("./processor")
model = VisionEncoderDecoderModel.from_pretrained("trainer_trocr/model").to(device)

# load image 
url = "D:\\trocr\\1538983940018_363.png"
image = Image.open(url).convert("RGB")

pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(generated_ids)
print("OCR output:", generated_text)

The model takes ages to load and I get results such as: död död död 鈦 鈦 鈦ਿਕਿਕਿਕ Première Première Premièreчиличиличили 锢 锢 锢

I’ve tried different combinations, but I can’t get the model to perform as well as during training. If you have any suggestions, I would be really grateful. Have a nice day, and thank you for all the work you’ve put into trocr.