T5 as Decoder for OCR

Since T5 is an Encoder-Decoder model and I want to use it as Decoder only for OCR Task, I created custom T5 Decoder class:

from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Stack
import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput

class T5DecoderOnlyForCausalLM(T5PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.decoder = T5Stack(config, self.shared)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.is_decoder = True
        config.is_decoder = True
        config.use_cache = False

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        decoder_input_ids=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        use_cache=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = decoder_outputs.last_hidden_state
        logits = self.lm_head(last_hidden_state)
        hidden_states = decoder_outputs.hidden_states
        past_key_values = decoder_outputs.past_key_values
        attentions = decoder_outputs.attentions
        cross_attentions = decoder_outputs.cross_attentions

        # Assuming you want to return the BaseModelOutputWithPastAndCrossAttentions
        return CausalLMOutputWithCrossAttentions(
            logits = logits,
            past_key_values = past_key_values,
            hidden_states = hidden_states,
            attentions = attentions,
            cross_attentions = cross_attentions
        )


    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

I need to use it for multilingual purpose. Hence I’m using ByT5 Tokenizer:

tokenizer = ByT5Tokenizer.from_pretrained('google/byt5-base')
image_processor=ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)

This is how I’m defining the model:

encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
decoder = T5DecoderOnlyForCausalLM.from_pretrained("google/byt5-base")
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

But for each epoch, the predicted labels are same. Training args and trainer looks like this:

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False, ##
    fp16_full_eval=False,  # Disable FP16 full evaluation
    output_dir="/content/",
    logging_steps=2,
    save_steps=5,
    eval_steps=5,
)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset,
    eval_dataset=dataset_val,
    data_collator=default_data_collator,
)

I’m not getting why it is generating same text/label for each epoch.

I think you should be using T5ForConditionalGeneration for decoder:

decoder = T5ForConditionalGeneration.from_pretrained(model_name)

Additionally, you could look into Google/Deplot or Google/Matcha or Pix2Struct for your task.

@Sandy1857 I considered it too but its too an encoder-decoder model, the custom decoder I made is inspired from it only with just the decoder part.
Also the reason behind using T5 with Byt5Tokenizer was to make it multilingual.

@Ishan4041 Got it. Please do post your results once you get it working/converged.