How to fine tune TrOCR model properly?

Machine learning neophyte here, so apologies in advance for a “dumb” question.

I have been trying to build a TrOCR model using VisionEncoderDecoderModel with the checkpoint ‘microsoft/trocr-base-handwritten’. I have tried the other checkpoints too, but my fine-tuning degrades the model instead of improving it, and I'd like some help understanding what goes wrong and how to fix it. I have been using PyTorch Lightning for the training/fine-tuning; my code is below. Out of the box (with the above checkpoint) the model generates pretty accurate results, but after my training/fine-tuning it gets worse instead of better.

Some info: I am fine-tuning on the IAM dataset. The initial loss is around 8 when training starts, and it never goes below 4…
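To put those numbers in perspective: the `loss` returned by a Hugging Face seq2seq model is a mean token-level cross-entropy in nats (natural log, as `torch.nn.CrossEntropyLoss` computes it), so exp(loss) is the per-token perplexity. The tiny sketch below only does that arithmetic; the helper name `perplexity` is illustrative.

```python
import math


def perplexity(mean_ce_loss: float) -> float:
    """Perplexity implied by a mean token-level cross-entropy loss measured in nats."""
    return math.exp(mean_ce_loss)


# The two loss values mentioned in the question.
for loss in (8.0, 4.0):
    print(f"loss {loss:.1f} -> perplexity {perplexity(loss):,.0f}")
```

A loss of 8 corresponds to the model being, on average, about as uncertain as a uniform choice among roughly 3,000 tokens per step, and a loss of 4 to roughly 55, which is hard to reconcile with a checkpoint that already transcribes IAM accurately.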

Huge thanks for any help!

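import torch
import pytorch_lightning as pl
from torch.optim import AdamW
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')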

class TrOCR_Image_to_Text(pl.LightningModule):
    def __init__(self):
        super().__init__()
        model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten') 
        model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
        model.config.pad_token_id = processor.tokenizer.pad_token_id
        model.config.vocab_size = model.config.decoder.vocab_size

        # set beam search parameters
        model.config.eos_token_id = processor.tokenizer.sep_token_id
        model.config.max_length = 89
        model.config.early_stopping = True
        model.config.no_repeat_ngram_size = 3
        model.config.length_penalty = 2.0
        model.config.num_beams = 2

        self.vit = model

    def generate(self, pixel_values):
        # VisionEncoderDecoderModel.generate takes the image tensor (pixel_values), not token ids
        return self.vit.generate(pixel_values)

    def forward(self, batch):
        # Lightning already moves the module and each batch to the right device,
        # so no manual .to(device) call is needed here
        pixel_values, labels = batch

        outputs = self.vit(pixel_values=pixel_values, labels=labels)

        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("training_loss", loss)

        return loss


    def validation_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("validation_loss", loss, prog_bar=True, on_epoch=True)

        model = self.vit
        
        x,y = batch

        outputs = model.generate(x.to(device))

        cer = compute_cer(pred_ids=outputs, label_ids=y)
        valid_cer = cer

        self.log("cer", valid_cer, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss = self(batch)  # forward() takes only the batch

        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return eval_dataloader

    def test_dataloader(self):
        return eval_dataloader


# CER metric
from datasets import load_metric
cer_metric = load_metric("cer")

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # copy first so the -100 -> pad replacement does not modify the batch tensor in place
    label_ids = label_ids.clone()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer
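`compute_cer` undoes a masking step: it assumes the dataset replaced padding token ids in the labels with -100, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so padded positions don't contribute to the loss. The dataset code isn't shown, so that is an assumption. A plain-Python sketch of both directions (the helper names and token ids are hypothetical; pad id 1 mirrors RoBERTa-style tokenizers):

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(label_ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Replace padding tokens with ignore_index so the loss skips those positions."""
    return [[ignore_index if tok == pad_token_id else tok for tok in seq]
            for seq in label_ids]


def unmask_padding(label_ids, pad_token_id, ignore_index=IGNORE_INDEX):
    """Inverse of mask_padding, i.e. what compute_cer does before batch_decode."""
    return [[pad_token_id if tok == ignore_index else tok for tok in seq]
            for seq in label_ids]


# Hypothetical ids: 0 = <s>, 2 = </s>, 1 = <pad>; 713 and 16 are arbitrary text tokens.
labels = [[0, 713, 16, 2, 1, 1]]
masked = mask_padding(labels, pad_token_id=1)
print(masked)                                            # [[0, 713, 16, 2, -100, -100]]
print(unmask_padding(masked, pad_token_id=1) == labels)  # True
```

If the labels are not masked this way, the model is also trained to predict long runs of padding, which inflates the loss.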

Hi,

Thanks for your interest in TrOCR! Actually, the checkpoint you are loading (i.e. microsoft/trocr-base-handwritten) has already been fine-tuned on the IAM dataset, so further fine-tuning on the same dataset is unlikely to help.

Instead, it makes sense to start from a pre-trained-only checkpoint (namely microsoft/trocr-base-stage1 or microsoft/trocr-large-stage1), and fine-tune it on the IAM dataset (or another dataset of interest). I illustrate this in my notebooks here.
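When starting from a stage-1 checkpoint, the special-token settings also set in the question's code presumably still need to be applied to the freshly loaded model. The sketch below factors them into a helper (`configure_trocr_for_finetuning` is a hypothetical name) and runs it on offline stand-ins; with the real libraries you would pass `VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")` and `processor.tokenizer` instead. Generation settings such as `num_beams` or `max_length` are deliberately left out, since newer transformers releases expect those on `model.generation_config`.

```python
from types import SimpleNamespace


def configure_trocr_for_finetuning(model, tokenizer):
    """Copy the tokenizer's special-token ids onto a VisionEncoderDecoder-style config."""
    cfg = model.config
    cfg.decoder_start_token_id = tokenizer.cls_token_id  # first token fed to the decoder
    cfg.pad_token_id = tokenizer.pad_token_id            # used when shifting labels right
    cfg.eos_token_id = tokenizer.sep_token_id            # where generation should stop
    cfg.vocab_size = cfg.decoder.vocab_size              # top-level config mirrors the decoder
    return model


# Offline stand-ins for the real model and tokenizer (hypothetical values).
tok = SimpleNamespace(cls_token_id=0, pad_token_id=1, sep_token_id=2)
model = SimpleNamespace(config=SimpleNamespace(decoder=SimpleNamespace(vocab_size=50265)))

configure_trocr_for_finetuning(model, tok)
print(model.config.decoder_start_token_id, model.config.pad_token_id,
      model.config.eos_token_id, model.config.vocab_size)
```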


Thank you for your prompt reply! One more question: after trying to fine-tune the pre-trained [microsoft/trocr-base-stage1] following your notebooks (BTW, huge thanks for them!), I get an error during the validation step (error output is below). The problem seems to be that the pre-trained model [microsoft/trocr-base-stage1], when calling model.generate (autoregressive decoding in the validation step), outputs either a dot or an empty string, while the ground truth is much longer than that. Any advice on how to solve this issue?

/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in _preprocess(truth, hypothesis, truth_transform, hypothesis_transform)
    327         raise ValueError(
    328             "number of ground truth inputs ({}) and hypothesis inputs ({}) must match.".format(
--> 329                 len(transformed_truth), len(transformed_hypothesis)
    330             )
    331         )

ValueError: number of ground truth inputs (46) and hypothesis inputs (1) must match.

Here are the decoded predictions and labels for the first batch (size 16), printed inside compute_cer just before they are passed to the CER metric.

Prediction:
['.', '.', '.', '.', '', '.', '.', '.', '.', 'in.', '.', '.', '.', '.', '.', '.'] 
Labels:
 ['their balances, will gain some immediate help.', 'ready for another bustling season. The pessimists', 'between puffs of light cumulus cloud.', 'boycotting the London talks on the', 'but half inquisitively. As is the case in Fanny', 'borders between private and public purse were', 'Cecil frowned in disappointment as he focussed upon the', 'precipitates were weighed on a semi-micro balance', 'Chinese or native forms of mantis.', 'too, are necessary, but a careful examination', 'Richards, with lavish and suitably gaudy', "up on 1960. Britain's business men are right to back", 'It seems likely that a continuous', 'hundred reasons why it would be', 'as much as to the expectant mother, who', 'a muffled bell taps the sky. Here we']
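The two numbers in the error message match the character lengths of the first label and the first prediction printed above, which can be checked directly:

```python
# First label/prediction pair from the printed batch above.
first_label = 'their balances, will gain some immediate help.'
first_pred = '.'

print(len(first_label), len(first_pred))  # 46 1
```

That hints the metric splits each string into characters and then compares the resulting lists as if they were separate sentences. One possible explanation, not verified here, is a version mismatch between the `cer` metric script and the installed `jiwer` package.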

Below is the full error:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_1998/1647004826.py in <module>
     11 #model = TrOCR_Image_to_Text()
     12 trainer = Trainer(gpus=1,  max_epochs=5)
---> 13 trainer.fit(model, datamodule=dataset)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    734             train_dataloaders = train_dataloader
    735         self._call_and_handle_interrupt(
--> 736             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    737         )
    738 

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    680         """
    681         try:
--> 682             return trainer_fn(*args, **kwargs)
    683         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    684         except KeyboardInterrupt as exception:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    768         # TODO: ckpt_path only in v1.7
    769         ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 770         self._run(model, ckpt_path=ckpt_path)
    771 
    772         assert self.state.stopped

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1191 
   1192         # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1193         self._dispatch()
   1194 
   1195         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
   1270             self.training_type_plugin.start_predicting(self)
   1271         else:
-> 1272             self.training_type_plugin.start_training(self)
   1273 
   1274     def run_stage(self):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    200     def start_training(self, trainer: "pl.Trainer") -> None:
    201         # double dispatch to initiate the training loop
--> 202         self._results = trainer.run_stage()
    203 
    204     def start_evaluating(self, trainer: "pl.Trainer") -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
   1280         if self.predicting:
   1281             return self._run_predict()
-> 1282         return self._run_train()
   1283 
   1284     def _pre_training_routine(self):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1302             self.progress_bar_callback.disable()
   1303 
-> 1304         self._run_sanity_check(self.lightning_module)
   1305 
   1306         # enable train mode

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
   1366             # run eval step
   1367             with torch.no_grad():
-> 1368                 self._evaluation_loop.run()
   1369 
   1370             self.call_hook("on_sanity_check_end")

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    107         dl_max_batches = self._max_batches[dataloader_idx]
    108 
--> 109         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
    110 
    111         # store batch level output per dataloader

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
    121         # lightning module methods
    122         with self.trainer.profiler.profile("evaluation_step_and_end"):
--> 123             output = self._evaluation_step(batch, batch_idx, dataloader_idx)
    124             output = self._evaluation_step_end(output)
    125 

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
    213             self.trainer.lightning_module._current_fx_name = "validation_step"
    214             with self.trainer.profiler.profile("validation_step"):
--> 215                 output = self.trainer.accelerator.validation_step(step_kwargs)
    216 
    217         return output

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, step_kwargs)
    234         """
    235         with self.precision_plugin.val_step_context():
--> 236             return self.training_type_plugin.validation_step(*step_kwargs.values())
    237 
    238     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
    217 
    218     def validation_step(self, *args, **kwargs):
--> 219         return self.model.validation_step(*args, **kwargs)
    220 
    221     def test_step(self, *args, **kwargs):

~/PDFtoLatexOCR/Modeling/Lightning_Models/TrOCR_HuggingFace_IAM.py in validation_step(self, batch, batch_idx)
    101         outputs = model.generate(x.to(device))
    102 
--> 103         cer = compute_cer(pred_ids=outputs, label_ids=y)
    104         valid_cer = cer
    105 

~/PDFtoLatexOCR/Modeling/Lightning_Models/TrOCR_HuggingFace_IAM.py in compute_cer(pred_ids, label_ids)
     30 
     31 
---> 32     cer = cer_metric.compute(predictions=pred_str, references=label_str)
     33 
     34     return cer

/opt/conda/lib/python3.7/site-packages/datasets/metric.py in compute(self, predictions, references, **kwargs)
    402             references = self.data["references"]
    403             with temp_seed(self.seed):
--> 404                 output = self._compute(predictions=predictions, references=references, **kwargs)
    405 
    406             if self.buf_writer is not None:

~/.cache/huggingface/modules/datasets_modules/metrics/cer/4e547cc82fc2e597c84fe25f48ed77e3a9acfd354308fe654ccbc6ea9473a61a/cer.py in _compute(self, predictions, references, concatenate_texts)
    132                 prediction,
    133                 truth_transform=cer_transform,
--> 134                 hypothesis_transform=cer_transform,
    135             )
    136             incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]

/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in compute_measures(truth, hypothesis, truth_transform, hypothesis_transform, **kwargs)
    208     # Preprocess truth and hypothesis
    209     truth, hypothesis = _preprocess(
--> 210         truth, hypothesis, truth_transform, hypothesis_transform
    211     )
    212 

/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in _preprocess(truth, hypothesis, truth_transform, hypothesis_transform)
    327         raise ValueError(
    328             "number of ground truth inputs ({}) and hypothesis inputs ({}) must match.".format(
--> 329                 len(transformed_truth), len(transformed_hypothesis)
    330             )
    331         )

ValueError: number of ground truth inputs (46) and hypothesis inputs (1) must match.
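As a library-independent cross-check (not a replacement for the `cer` metric, and without its whitespace normalization), CER can be computed as the total character-level Levenshtein distance divided by the total number of reference characters. A minimal sketch; `levenshtein` and `simple_cer` are illustrative names:

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of character insertions, deletions and substitutions turning a into b."""
    prev = list(range(len(b) + 1))  # distances from the empty prefix of a
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution (free if equal)
        prev = curr
    return prev[-1]


def simple_cer(predictions, references):
    """Total edit distance divided by total reference length (no text normalization)."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    errors = sum(levenshtein(p, r) for p, r in zip(predictions, references))
    total = sum(len(r) for r in references)
    return errors / total if total else 0.0


# Two pairs taken from the printed batch above.
print(simple_cer(['.', 'in.'],
                 ['between puffs of light cumulus cloud.',
                  'too, are necessary, but a careful examination']))
```

Because it never splits strings into separate "sentences", this version cannot raise the length-mismatch error above, which makes it handy for confirming whether the problem lies in the metric or in the model's predictions.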