Hugging Face: problem saving `VisionEncoderDecoderModel` to `TorchScript`

  • Python: 3.9.12
  • transformers: 4.26.0
  • torch: 2.0.0
  • Pillow: 9.2.0

Hi, I want to save a local checkpoint of the Hugging Face `transformers.VisionEncoderDecoderModel` to TorchScript via `torch.jit.trace`, using the code below:

import torch
from PIL import Image
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)

processor = TrOCRProcessor.from_pretrained('weights_with_custom_vocab', local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained('weights_with_custom_vocab', local_files_only=True, torchscript=True)

image_file_name = '00453cb6-5ea8-4988-aa93-dcb8e29719ec.png'
text_file_name = '00453cb6-5ea8-4988-aa93-dcb8e29719ec.txt'

with open(f'OCR_data_small/{text_file_name}', 'r', encoding='utf-8') as f:
    text = f.read()
# load the image and convert to RGB
image = Image.open(f'OCR_data_small/{image_file_name}').convert("RGB")
# prepare image (i.e. resize + normalize)
pixel_values = processor(image, return_tensors="pt").pixel_values
labels = processor.tokenizer(text, padding="max_length", max_length=20).input_ids
# replace pad token ids with -100 so they are ignored by the loss
labels = [label if label != processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = pixel_values.squeeze(), torch.tensor(labels)

# resize the decoder embeddings to match the custom vocabulary
model.decoder.resize_token_embeddings(len(processor.tokenizer))


# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search / generation parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 20
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

model.eval()

traced_model = torch.jit.trace(model, (encoding[0].unsqueeze(0), encoding[1]))

yet it raises the error below:

IndexError                                Traceback (most recent call last)
Cell In [122], line 40
     36 model.config.num_beams = 4
     38 model.eval()
---> 40 traced_model = torch.jit.trace(model, (encoding[0].unsqueeze(0), encoding[1]))

File /opt/conda/lib/python3.9/site-packages/torch/jit/_trace.py:794, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    792         else:
    793             raise RuntimeError("example_kwarg_inputs should be a dict")
--> 794     return trace_module(
    795         func,
    796         {"forward": example_inputs},
    797         None,
    798         check_trace,
    799         wrap_check_inputs(check_inputs),
    800         check_tolerance,
    801         strict,
    802         _force_outplace,
    803         _module_class,
    804         example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    805         _store_inputs=_store_inputs
    806     )
    807 if (
    808     hasattr(func, "__self__")
    809     and isinstance(func.__self__, torch.nn.Module)
    810     and func.__name__ == "forward"
    811 ):
    812     if example_inputs is None:

File /opt/conda/lib/python3.9/site-packages/torch/jit/_trace.py:1056, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1054 else:
   1055     example_inputs = make_tuple(example_inputs)
-> 1056     module._c._create_method_from_trace(
   1057         method_name,
   1058         func,
   1059         example_inputs,
   1060         var_lookup_fn,
   1061         strict,
   1062         _force_outplace,
   1063         argument_names,
   1064         _store_inputs
   1065     )
   1067 check_trace_method = module._c._get_method(method_name)
   1069 # Check the trace against new traces created from user-specified inputs

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
   1486         recording_scopes = False
   1487 try:
-> 1488     result = self.forward(*input, **kwargs)
   1489 finally:
   1490     if recording_scopes:

File /opt/conda/lib/python3.9/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py:609, in VisionEncoderDecoderModel.forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    604     decoder_input_ids = shift_tokens_right(
    605         labels, self.config.pad_token_id, self.config.decoder_start_token_id
    606     )
    608 # Decode
--> 609 decoder_outputs = self.decoder(
    610     input_ids=decoder_input_ids,
    611     attention_mask=decoder_attention_mask,
    612     encoder_hidden_states=encoder_hidden_states,
    613     encoder_attention_mask=encoder_attention_mask,
    614     inputs_embeds=decoder_inputs_embeds,
    615     output_attentions=output_attentions,
    616     output_hidden_states=output_hidden_states,
    617     use_cache=use_cache,
    618     past_key_values=past_key_values,
    619     return_dict=return_dict,
    620     **kwargs_decoder,
    621 )
    623 # Compute loss independent from decoder (as some shift the logits inside them)
    624 loss = None

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
   1486         recording_scopes = False
   1487 try:
-> 1488     result = self.forward(*input, **kwargs)
   1489 finally:
   1490     if recording_scopes:

File /opt/conda/lib/python3.9/site-packages/transformers/models/trocr/modeling_trocr.py:959, in TrOCRForCausalLM.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    956 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    958 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 959 outputs = self.model.decoder(
    960     input_ids=input_ids,
    961     attention_mask=attention_mask,
    962     encoder_hidden_states=encoder_hidden_states,
    963     encoder_attention_mask=encoder_attention_mask,
    964     head_mask=head_mask,
    965     cross_attn_head_mask=cross_attn_head_mask,
    966     past_key_values=past_key_values,
    967     inputs_embeds=inputs_embeds,
    968     use_cache=use_cache,
    969     output_attentions=output_attentions,
    970     output_hidden_states=output_hidden_states,
    971     return_dict=return_dict,
    972 )
    974 logits = self.output_projection(outputs[0])
    976 loss = None

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
   1486         recording_scopes = False
   1487 try:
-> 1488     result = self.forward(*input, **kwargs)
   1489 finally:
   1490     if recording_scopes:

File /opt/conda/lib/python3.9/site-packages/transformers/models/trocr/modeling_trocr.py:642, in TrOCRDecoder.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    639 past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
    641 if inputs_embeds is None:
--> 642     inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    644 if self.config.use_learned_position_embeddings:
    645     embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
   1486         recording_scopes = False
   1487 try:
-> 1488     result = self.forward(*input, **kwargs)
   1489 finally:
   1490     if recording_scopes:

File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
    161 def forward(self, input: Tensor) -> Tensor:
--> 162     return F.embedding(
    163         input, self.weight, self.padding_idx, self.max_norm,
    164         self.norm_type, self.scale_grad_by_freq, self.sparse)

File /opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2204     # Note [embedding_renorm set_grad_enabled]
   2205     # XXX: equivalent to
   2206     # with torch.no_grad():
   2207     #   torch.embedding_renorm_
   2208     # remove once script supports set_grad_enabled
   2209     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self
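
From what I can tell, `torch.nn.functional.embedding` raises this exact error whenever an input id falls outside `[0, num_embeddings)`, including negative values such as the `-100` I substitute for padding. A minimal standalone illustration of my understanding:

import torch
import torch.nn.functional as F

emb = torch.nn.Embedding(10, 4)  # toy embedding table with valid ids 0..9
F.embedding(torch.tensor([-100]), emb.weight)  # IndexError: index out of range in self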

I am sure I have resized the model's embeddings to match the tokenizer, so the vocabulary size itself should not be the problem. Does anyone have a proper way of saving a VisionEncoderDecoderModel (or another transformers seq2seq model) to TorchScript?
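
For what it's worth, my current (unverified) suspicion is that `torch.jit.trace` passes the example inputs positionally, and the second positional parameter of `VisionEncoderDecoderModel.forward` is `decoder_input_ids`, not `labels` (see the signature in the traceback above), so the `-100` values I substituted for padding get looked up directly in the embedding table. A minimal sketch of the trace call under that assumption, using padded but in-vocabulary token ids instead:

# sketch under the assumption above (untested): trace with decoder_input_ids
# that contain only valid vocabulary ids, instead of labels filled with -100
decoder_input_ids = processor.tokenizer(
    text,
    padding="max_length",
    max_length=20,
    return_tensors="pt",
).input_ids  # shape (1, 20); every id is a valid embedding index

traced_model = torch.jit.trace(model, (pixel_values, decoder_input_ids))
torch.jit.save(traced_model, 'traced_trocr.pt')

Even if that avoids the IndexError, I am not sure tracing the training-style forward is the right way to export a model that is normally used through generate(), so any pointers to the proper approach would be appreciated.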