Thank you for your prompt reply! One more question: after trying to fine-tune the pretrained [microsoft/trocr-base-stage1] model following your notebooks (BTW, huge thanks for them!), I get an error during the validation step (the error is shown below). The problem seems to be that, when model.generate is called (autoregressive computation at the validation step), the pretrained [microsoft/trocr-base-stage1] model outputs either a dot or an empty string, while the ground truth is much longer than that. Any advice on how to solve this issue?
/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in _preprocess(truth, hypothesis, truth_transform, hypothesis_transform)
327 raise ValueError(
328 "number of ground truth inputs ({}) and hypothesis inputs ({}) must match.".format(
--> 329 len(transformed_truth), len(transformed_hypothesis)
330 )
331 )
ValueError: number of ground truth inputs (46) and hypothesis inputs (1) must match.
Here I’m printing the predictions and labels for the first batch (size 16) just before they are passed to the CER function.
Prediction:
['.', '.', '.', '.', '', '.', '.', '.', '.', 'in.', '.', '.', '.', '.', '.', '.']
Labels:
['their balances, will gain some immediate help.', 'ready for another bustling season. The pessimists', 'between puffs of light cumulus cloud.', 'boycotting the London talks on the', 'but half inquisitively. As is the case in Fanny', 'borders between private and public purse were', 'Cecil frowned in disappointment as he focussed upon the', 'precipitates were weighed on a semi-micro balance', 'Chinese or native forms of mantis.', 'too, are necessary, but a careful examination', 'Richards, with lavish and suitably gaudy', "up on 1960. Britain's business men are right to back", 'It seems likely that a continuous', 'hundred reasons why it would be', 'as much as to the expectant mother, who', 'a muffled bell taps the sky. Here we']
Below is the full error:
ValueError Traceback (most recent call last)
/tmp/ipykernel_1998/1647004826.py in <module>
11 #model = TrOCR_Image_to_Text()
12 trainer = Trainer(gpus=1, max_epochs=5)
---> 13 trainer.fit(model, datamodule=dataset)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
734 train_dataloaders = train_dataloader
735 self._call_and_handle_interrupt(
--> 736 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
737 )
738
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
680 """
681 try:
--> 682 return trainer_fn(*args, **kwargs)
683 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
684 except KeyboardInterrupt as exception:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
768 # TODO: ckpt_path only in v1.7
769 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 770 self._run(model, ckpt_path=ckpt_path)
771
772 assert self.state.stopped
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1191
1192 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1193 self._dispatch()
1194
1195 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
1270 self.training_type_plugin.start_predicting(self)
1271 else:
-> 1272 self.training_type_plugin.start_training(self)
1273
1274 def run_stage(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
--> 202 self._results = trainer.run_stage()
203
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
1280 if self.predicting:
1281 return self._run_predict()
-> 1282 return self._run_train()
1283
1284 def _pre_training_routine(self):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1302 self.progress_bar_callback.disable()
1303
-> 1304 self._run_sanity_check(self.lightning_module)
1305
1306 # enable train mode
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
1366 # run eval step
1367 with torch.no_grad():
-> 1368 self._evaluation_loop.run()
1369
1370 self.call_hook("on_sanity_check_end")
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
107 dl_max_batches = self._max_batches[dataloader_idx]
108
--> 109 dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
110
111 # store batch level output per dataloader
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
121 # lightning module methods
122 with self.trainer.profiler.profile("evaluation_step_and_end"):
--> 123 output = self._evaluation_step(batch, batch_idx, dataloader_idx)
124 output = self._evaluation_step_end(output)
125
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
213 self.trainer.lightning_module._current_fx_name = "validation_step"
214 with self.trainer.profiler.profile("validation_step"):
--> 215 output = self.trainer.accelerator.validation_step(step_kwargs)
216
217 return output
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, step_kwargs)
234 """
235 with self.precision_plugin.val_step_context():
--> 236 return self.training_type_plugin.validation_step(*step_kwargs.values())
237
238 def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
217
218 def validation_step(self, *args, **kwargs):
--> 219 return self.model.validation_step(*args, **kwargs)
220
221 def test_step(self, *args, **kwargs):
~/PDFtoLatexOCR/Modeling/Lightning_Models/TrOCR_HuggingFace_IAM.py in validation_step(self, batch, batch_idx)
101 outputs = model.generate(x.to(device))
102
--> 103 cer = compute_cer(pred_ids=outputs, label_ids=y)
104 valid_cer = cer
105
~/PDFtoLatexOCR/Modeling/Lightning_Models/TrOCR_HuggingFace_IAM.py in compute_cer(pred_ids, label_ids)
30
31
---> 32 cer = cer_metric.compute(predictions=pred_str, references=label_str)
33
34 return cer
/opt/conda/lib/python3.7/site-packages/datasets/metric.py in compute(self, predictions, references, **kwargs)
402 references = self.data["references"]
403 with temp_seed(self.seed):
--> 404 output = self._compute(predictions=predictions, references=references, **kwargs)
405
406 if self.buf_writer is not None:
~/.cache/huggingface/modules/datasets_modules/metrics/cer/4e547cc82fc2e597c84fe25f48ed77e3a9acfd354308fe654ccbc6ea9473a61a/cer.py in _compute(self, predictions, references, concatenate_texts)
132 prediction,
133 truth_transform=cer_transform,
--> 134 hypothesis_transform=cer_transform,
135 )
136 incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in compute_measures(truth, hypothesis, truth_transform, hypothesis_transform, **kwargs)
208 # Preprocess truth and hypothesis
209 truth, hypothesis = _preprocess(
--> 210 truth, hypothesis, truth_transform, hypothesis_transform
211 )
212
/opt/conda/lib/python3.7/site-packages/jiwer/measures.py in _preprocess(truth, hypothesis, truth_transform, hypothesis_transform)
327 raise ValueError(
328 "number of ground truth inputs ({}) and hypothesis inputs ({}) must match.".format(
--> 329 len(transformed_truth), len(transformed_hypothesis)
330 )
331 )
ValueError: number of ground truth inputs (46) and hypothesis inputs (1) must match.