Hi, I am trying to run inference on an image using the Pix2Struct vision transformer model. Currently I generate one inference at a time; the code I am using is below.
import torch
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

# Use the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DePlot processor and model
processor = Pix2StructProcessor.from_pretrained(
    "google/deplot", is_vqa=True
)
model = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot", is_vqa=True
).to(device)

# Load a single test image
with open("./data/test_imgs/test.png", "rb") as f:
    image = Image.open(f).convert("RGB")

# Preprocess the image together with the text prompt
inputs = processor(
    images=image,
    text="Generate underlying data table of the figure below:",
    return_tensors="pt",
).to(device)

# Generate and decode one prediction
predictions = model.generate(**inputs, max_new_tokens=512)
deplot_result = processor.decode(
    predictions[0], skip_special_tokens=True
)
print(deplot_result)
However, the inference time for this method is roughly 45 seconds per image, which is not viable for our project. Is there a way to adapt this code to process images in batches, so that I can generate multiple predictions at the same time?
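For reference, this is the direction I have in mind: a minimal sketch that assumes Pix2StructProcessor accepts a list of images with a matching list of prompts and that processor.batch_decode can decode the whole batch at once. The image_paths list and batch_size value are placeholders I made up for illustration.

import torch
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Pix2StructProcessor.from_pretrained("google/deplot", is_vqa=True)
model = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot", is_vqa=True
).to(device)

# Hypothetical list of image paths; replace with the real test set.
image_paths = ["./data/test_imgs/test1.png", "./data/test_imgs/test2.png"]
prompt = "Generate underlying data table of the figure below:"
batch_size = 4  # assumed value, tune to fit GPU memory

results = []
for start in range(0, len(image_paths), batch_size):
    batch_paths = image_paths[start:start + batch_size]
    images = [Image.open(p).convert("RGB") for p in batch_paths]

    # Assumes the processor can pad a list of images/prompts into one batch.
    inputs = processor(
        images=images,
        text=[prompt] * len(images),
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        predictions = model.generate(**inputs, max_new_tokens=512)

    # Decode every sequence in the batch in one call.
    results.extend(processor.batch_decode(predictions, skip_special_tokens=True))

print(results)

Would something along these lines work with this model, and is this the recommended way to batch, or is there a better approach (e.g. a different batch size or padding setting) to bring the per-image latency down?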