- Python: 3.9.12
- transformers: 4.26.0
- torch: 2.0.0
- pillow: 9.2.0
Hi, I want to export a local checkpoint of a Hugging Face transformers.VisionEncoderDecoderModel to TorchScript via torch.jit.trace, using the code below:
import torch
from PIL import Image
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)
processor = TrOCRProcessor.from_pretrained('weights_with_custom_vocab', local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained('weights_with_custom_vocab', local_files_only=True, torchscript=True)
image_file_name = '00453cb6-5ea8-4988-aa93-dcb8e29719ec.png'
text_file_name = '00453cb6-5ea8-4988-aa93-dcb8e29719ec.txt'
with open(f'OCR_data_small/{text_file_name}', 'r', encoding='utf-8') as f:
    text = f.read()
# prepare image (the processor resizes + normalizes it)
image = Image.open(f'OCR_data_small/{image_file_name}').convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
labels = processor.tokenizer(text, padding="max_length", max_length=20).input_ids
labels = [label if label != processor.tokenizer.pad_token_id else -100 for label in labels]
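# note: -100 is the ignore_index of CrossEntropyLoss, not a valid token id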
encoding = (pixel_values.squeeze(), torch.tensor(labels))
model.decoder.resize_token_embeddings(len(processor.tokenizer))
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 20
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
model.eval()
traced_model = torch.jit.trace(model, (encoding[0].unsqueeze(0), encoding[1]))
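For reference, the example inputs I hand to trace have these shapes (the pixel_values size depends on the processor config; 384x384 is the TrOCR default, so that part is an assumption):

print(encoding[0].unsqueeze(0).shape)  # torch.Size([1, 3, 384, 384]) with the default TrOCR processor
print(encoding[1].shape)               # torch.Size([20]); still contains the -100 entries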
Yet the trace call raises the error below:
IndexError Traceback (most recent call last)
Cell In [122], line 40
36 model.config.num_beams = 4
38 model.eval()
---> 40 traced_model = torch.jit.trace(model, (encoding[0].unsqueeze(0), encoding[1]))
File /opt/conda/lib/python3.9/site-packages/torch/jit/_trace.py:794, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
792 else:
793 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 794 return trace_module(
795 func,
796 {"forward": example_inputs},
797 None,
798 check_trace,
799 wrap_check_inputs(check_inputs),
800 check_tolerance,
801 strict,
802 _force_outplace,
803 _module_class,
804 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
805 _store_inputs=_store_inputs
806 )
807 if (
808 hasattr(func, "__self__")
809 and isinstance(func.__self__, torch.nn.Module)
810 and func.__name__ == "forward"
811 ):
812 if example_inputs is None:
File /opt/conda/lib/python3.9/site-packages/torch/jit/_trace.py:1056, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1054 else:
1055 example_inputs = make_tuple(example_inputs)
-> 1056 module._c._create_method_from_trace(
1057 method_name,
1058 func,
1059 example_inputs,
1060 var_lookup_fn,
1061 strict,
1062 _force_outplace,
1063 argument_names,
1064 _store_inputs
1065 )
1067 check_trace_method = module._c._get_method(method_name)
1069 # Check the trace against new traces created from user-specified inputs
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
1486 recording_scopes = False
1487 try:
-> 1488 result = self.forward(*input, **kwargs)
1489 finally:
1490 if recording_scopes:
File /opt/conda/lib/python3.9/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py:609, in VisionEncoderDecoderModel.forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
604 decoder_input_ids = shift_tokens_right(
605 labels, self.config.pad_token_id, self.config.decoder_start_token_id
606 )
608 # Decode
--> 609 decoder_outputs = self.decoder(
610 input_ids=decoder_input_ids,
611 attention_mask=decoder_attention_mask,
612 encoder_hidden_states=encoder_hidden_states,
613 encoder_attention_mask=encoder_attention_mask,
614 inputs_embeds=decoder_inputs_embeds,
615 output_attentions=output_attentions,
616 output_hidden_states=output_hidden_states,
617 use_cache=use_cache,
618 past_key_values=past_key_values,
619 return_dict=return_dict,
620 **kwargs_decoder,
621 )
623 # Compute loss independent from decoder (as some shift the logits inside them)
624 loss = None
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
1486 recording_scopes = False
1487 try:
-> 1488 result = self.forward(*input, **kwargs)
1489 finally:
1490 if recording_scopes:
File /opt/conda/lib/python3.9/site-packages/transformers/models/trocr/modeling_trocr.py:959, in TrOCRForCausalLM.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
956 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
958 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 959 outputs = self.model.decoder(
960 input_ids=input_ids,
961 attention_mask=attention_mask,
962 encoder_hidden_states=encoder_hidden_states,
963 encoder_attention_mask=encoder_attention_mask,
964 head_mask=head_mask,
965 cross_attn_head_mask=cross_attn_head_mask,
966 past_key_values=past_key_values,
967 inputs_embeds=inputs_embeds,
968 use_cache=use_cache,
969 output_attentions=output_attentions,
970 output_hidden_states=output_hidden_states,
971 return_dict=return_dict,
972 )
974 logits = self.output_projection(outputs[0])
976 loss = None
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
1486 recording_scopes = False
1487 try:
-> 1488 result = self.forward(*input, **kwargs)
1489 finally:
1490 if recording_scopes:
File /opt/conda/lib/python3.9/site-packages/transformers/models/trocr/modeling_trocr.py:642, in TrOCRDecoder.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
639 past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
641 if inputs_embeds is None:
--> 642 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
644 if self.config.use_learned_position_embeddings:
645 embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
1486 recording_scopes = False
1487 try:
-> 1488 result = self.forward(*input, **kwargs)
1489 finally:
1490 if recording_scopes:
File /opt/conda/lib/python3.9/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
161 def forward(self, input: Tensor) -> Tensor:
--> 162 return F.embedding(
163 input, self.weight, self.padding_idx, self.max_norm,
164 self.norm_type, self.scale_grad_by_freq, self.sparse)
File /opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2204 # Note [embedding_renorm set_grad_enabled]
2205 # XXX: equivalent to
2206 # with torch.no_grad():
2207 # torch.embedding_renorm_
2208 # remove once script supports set_grad_enabled
2209 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
I am sure I have resized the model's embeddings to match the tokenizer. Does anyone know a proper way to save a VisionEncoderDecoderModel (or any other transformers seq2seq model) to TorchScript?
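One thing I notice: torch.jit.trace binds example inputs positionally, so encoding[1] lands in the second forward argument, decoder_input_ids, rather than labels; the -100 entries I inserted for loss masking would then be looked up directly in the embedding table, which might explain the IndexError. Below is a minimal sketch of what I am considering instead, passing the tensors by name via example_kwarg_inputs (the keyword exists in torch 2.0, as the traceback shows) and keeping real pad ids in the decoder input; this is my assumption, not a verified fix:

# decoder inputs without -100 masking (the masking is only meant for the loss)
decoder_input_ids = torch.tensor(
    processor.tokenizer(text, padding="max_length", max_length=20).input_ids
).unsqueeze(0)
traced_model = torch.jit.trace(
    model,
    example_kwarg_inputs={
        "pixel_values": encoding[0].unsqueeze(0),
        "decoder_input_ids": decoder_input_ids,
    },
    strict=False,  # the model returns a tuple when loaded with torchscript=True
)

Also, if I understand correctly, jit.trace only captures a single forward pass, so generation settings such as num_beams and max_length would not be baked into the traced module anyway.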