Hi,
I am trying to train the Deplot model using this following Pix2Struct example:
The bulk of the code is almost the same but with some minor adjustments.
My dataset
class DeplotDataset(Dataset):
def __init__(self, image_folder, text_folder, processor, transform=None):
self.image_folder = image_folder
self.text_folder = text_folder
self.processor = processor
self.transform = transform
self.image_filenames = sorted(os.listdir(image_folder))
self.text_filenames = sorted(os.listdir(text_folder))
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, index):
image_filename = self.image_filenames[index]
text_filename = self.text_filenames[index]
image_path = os.path.join(self.image_folder, image_filename)
text_path = os.path.join(self.text_folder, text_filename)
image = Image.open(image_path)
with open(text_path, 'r') as f:
text = f.read()
if self.transform:
image = self.transform(image)
encoding = self.processor(images=image, text="Generate underlying data table of the figure below:", return_tensors="pt", add_special_tokens=True, max_patches=MAX_PATCHES)
encoding = {k:v.squeeze() for k,v in encoding.items()}
encoding["text"] = text
return encoding
def collator(batch):
new_batch = {"flattened_patches":[], "attention_mask":[]}
texts = [item["text"] for item in batch]
text_inputs = processor(text=texts, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=True, max_length=20)
new_batch["labels"] = text_inputs.input_ids
for item in batch:
new_batch["flattened_patches"].append(item["flattened_patches"])
new_batch["attention_mask"].append(item["attention_mask"])
new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])
return new_batch
But I get this following error:
ValueError: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'NoneType'>.