I’m trying to convert a model I finetuned using (the model is here: Salesforce/blip-image-captioning-base · Hugging Face). I’m trying to convert it to TorchScript so I can use it for an android app I used two techniques to try this but I always get a shape error on both.
from transformers import BlipConfig
from torch.utils.mobile_optimizer import optimize_for_mobile
# Creating a dummy input
test_text = "this is a cat"
text_inputs = processor(text=test_text, return_tensors="pt")
tokenized_text = tokenizer.tokenize(test_text)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_inputs = processor(images=image, return_tensors="pt")
# Instantiating the model
model = BlipForConditionalGeneration(BlipConfig())
# The model needs to be in evaluation mode
model.eval()
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torchscript=True, local_files_only=True)
model.load_state_dict(torch.load("state_dict.pkl"))
# Creating the trace
traced_script_module = torch.jit.trace(model, {'vision_model': [image_inputs['pixel_values']],
'text_decoder': [text_inputs['input_ids'], text_inputs['attention_mask']]})
#traced_script_module = torch.jit.trace(model, {'vision_model': [image_inputs['pixel_values']],
#'text_decoder': [text_inputs['input_ids'], text_inputs['attention_mask']]})
#torch.jit.save(traced_model, "traced_model.pt")
#traced_script_module = torch.jit.trace(best_model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("optimised_image_caption_model.ptl")
Give the error
File ~\anaconda3\envs\pytorch_env\Lib\site-packages\transformers\models\blip\modeling_blip.py:247, in BlipVisionEmbeddings.forward(self, pixel_values)
246 def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
--> 247 batch_size = pixel_values.shape[0]
248 target_dtype = self.patch_embedding.weight.dtype
249 patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
AttributeError: 'dict' object has no attribute 'shape'
Whilst the other technique I tried
class DecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
scripted_gate = torch.jit.script(DecisionGate())
# Instantiating the model
temodel = BlipForConditionalGeneration(BlipConfig())
# The model needs to be in evaluation mode
temodel.eval()
temodel = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torchscript=True)
temodel.load_state_dict(torch.load("state_dict.pkl"))
temodel = model(scripted_gate)
scripted_cell = torch.jit.script(temodel)
print(scripted_gate.code)
print(scripted_cell.code)
Gives the error
File ~\anaconda3\envs\pytorch_env\Lib\site-packages\transformers\models\blip\modeling_blip.py:247, in BlipVisionEmbeddings.forward(self, pixel_values)
246 def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
--> 247 batch_size = pixel_values.shape[0]
248 target_dtype = self.patch_embedding.weight.dtype
249 patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
File ~\anaconda3\envs\pytorch_env\Lib\site-packages\torch\jit\_script.py:784, in RecursiveScriptModule.__getattr__(self, attr)
781 self.__dict__[attr] = script_method
782 return script_method
--> 784 return super().__getattr__(attr)
File ~\anaconda3\envs\pytorch_env\Lib\site-packages\torch\jit\_script.py:501, in ScriptModule.__getattr__(self, attr)
499 def __getattr__(self, attr):
500 if "_actual_script_module" not in self.__dict__:
--> 501 return super().__getattr__(attr)
502 return getattr(self._actual_script_module, attr)
File ~\anaconda3\envs\pytorch_env\Lib\site-packages\torch\nn\modules\module.py:1614, in Module.__getattr__(self, name)
1612 if name in modules:
1613 return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
1615 type(self).__name__, name))
AttributeError: 'RecursiveScriptModule' object has no attribute 'shape'
Any help will be much appreciated thanks