I am trying to fine-tune the google/deplot model. I have images of plots with annotations in JSON. I am able to populate the train and test datasets, but when I call `trainer.train()` from my `train` function, the `ImageDataCollator.__call__` method is invoked with an empty batch. The code is below. Is there anything else I need to do so that `ImageDataCollator` is called with a non-empty batch? Thanks for any hint.
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
from torchvision import transforms
import json
import logging
from transformers import TrainerCallback
# Custom Image Data Collator
class ImageDataCollator:
    """Collate dataset items into one batch of image tensors plus raw targets.

    Each item is expected to be a dict with an ``"image"`` entry that the
    transform pipeline accepts and, optionally, ``"text"`` and
    ``"annotation"`` entries. Items without an ``"image"`` key are reported
    on stdout and skipped.

    NOTE(review): the returned keys (``images``/``texts``/``annotations``)
    are this file's own naming. Confirm they match the keyword arguments of
    the model's ``forward`` before handing the batch to a ``Trainer``;
    presumably the processor has to convert them into model inputs first.
    """

    def __init__(self, desired_size=(768, 768)):
        """Build the resize + to-tensor pipeline.

        Args:
            desired_size: ``(height, width)`` every image is resized to.
        """
        # Kept separately so the placeholder batch does not have to reach
        # into the internals of the transform pipeline.
        self.desired_size = desired_size
        self.transform = transforms.Compose([
            transforms.Resize(desired_size),
            transforms.ToTensor(),
            # Add any other transformations here (e.g., normalization)
        ])

    def __call__(self, batch):
        """Turn a list of item dicts into a batch dict.

        Args:
            batch: Iterable of item dicts as produced by the dataset.

        Returns:
            A dict with ``'images'`` (a stacked tensor), ``'texts'`` and
            ``'annotations'`` (plain lists, one entry per usable item).
            When no item has an image, ``'images'`` is a single all-zero
            placeholder and both lists are empty.
        """
        images, texts, annotations = [], [], []
        for item in batch:
            if "image" not in item:
                print("Missing 'image' key in batch item:", item)
                continue
            images.append(self.transform(item["image"]))
            # Items without annotations may carry no targets at all.
            texts.append(item.get("text", ""))
            annotations.append(item.get("annotation", ""))

        if not images:
            print("No valid images in batch.")
            # Return placeholders for images, texts, and annotations
            return {
                'images': torch.zeros(1, 3, *self.desired_size),
                'texts': [],
                'annotations': [],
            }

        # Only the images are tensors. Texts and annotations are arbitrary
        # Python objects (str or decoded JSON), which torch.stack rejects,
        # so they stay as lists.
        return {
            'images': torch.stack(images),
            'texts': texts,
            'annotations': annotations,
        }
class ImageDataset(Dataset):
    """Map-style dataset over a directory of images with optional JSON annotations.

    The annotation for ``<stem>.<ext>`` is looked up as ``<stem>.json`` inside
    ``annotation_dir``. Every item is a dict that always contains ``"name"``,
    ``"annotation"`` and ``"text"``; ``"image"`` is present only when the file
    could be opened.
    """

    def __init__(self, image_dir, annotation_dir=None):
        """Index the regular files in ``image_dir``.

        Args:
            image_dir: Directory holding the image files.
            annotation_dir: Directory holding the ``.json`` annotations, or
                ``None`` when the images are unlabeled.
        """
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs.
        self.images = sorted(
            entry for entry in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, entry))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the item at position ``idx``.

        Returns:
            A dict with ``"name"`` (file name), ``"annotation"`` (decoded
            JSON, or ``""`` when there is none), ``"text"`` (same object as
            ``"annotation"``) and, if the file opened, ``"image"`` (an RGB
            image).
        """
        image_name = self.images[idx]
        image_path = os.path.join(self.image_dir, image_name)
        data = {"name": image_name}
        try:
            data["image"] = Image.open(image_path).convert("RGB")
        except Exception as e:
            # Best effort: an unreadable file yields an item without "image"
            # so the consumer can decide to skip it.
            print(f"Error opening image {image_path}: {e}")

        # Always populate both keys so consumers can rely on them even for
        # unlabeled data.
        data["annotation"] = self._load_annotation(image_name)
        data["text"] = data["annotation"]
        return data

    def _load_annotation(self, image_name):
        """Return the decoded JSON annotation for ``image_name``, or ``""`` if absent."""
        if not self.annotation_dir:
            return ""
        annotation_name = os.path.splitext(image_name)[0] + '.json'
        annotation_path = os.path.join(self.annotation_dir, annotation_name)
        try:
            # Explicit UTF-8: the platform default encoding is not always UTF-8.
            with open(annotation_path, 'r', encoding='utf-8') as file:
                return json.load(file)
        except FileNotFoundError:
            return ""
def load_dataset(train_image_path, train_annotation_path, test_image_path, tokenizer):
    """Create the train dataset, the test dataset and the batch collator.

    Args:
        train_image_path: Directory with the training images.
        train_annotation_path: Directory with the training annotations.
        test_image_path: Directory with the test images (no annotations).
        tokenizer: Accepted for the caller's convenience; not used here.

    Returns:
        A ``(train_dataset, test_dataset, data_collator)`` tuple.
    """
    # Tuple elements are evaluated left to right: train, test, collator.
    return (
        ImageDataset(train_image_path, train_annotation_path),
        ImageDataset(test_image_path),  # No annotations for test data
        ImageDataCollator(desired_size=(768, 768)),
    )
def train(model, train_dataset, test_dataset, data_collator):
    """Run a three-epoch ``Trainer`` fine-tuning loop on ``model``.

    Args:
        model: The model to fine-tune; updated in place by the trainer.
        train_dataset: Dataset used for training.
        test_dataset: Dataset evaluated at the end of every epoch.
        data_collator: Callable that turns a list of dataset items into a
            batch.
    """
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        evaluation_strategy="epoch",
        # By default the Trainer drops every item key that is not a parameter
        # of model.forward before the collator sees the items. The dataset's
        # own keys ("image", "text", ...) would all be stripped, leaving the
        # collator with empty dicts. Keep them so the collator gets the data.
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=data_collator,
    )
    trainer.train()
def save(model, processor, savedir):
    """Write the model and then the processor into ``savedir``.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        processor: Object exposing ``save_pretrained(path)``.
        savedir: Target directory passed to both ``save_pretrained`` calls.
    """
    # Model first, processor second: same order as two explicit calls.
    for artifact in (model, processor):
        artifact.save_pretrained(savedir)
def main():
    """Load the local checkpoint, fine-tune it on the chart data and save it.

    All locations are hard-coded Windows paths below one project directory.
    """
    logging.basicConfig(level=logging.DEBUG)

    # Single source of truth for the directory that every path shares.
    project_dir = "D:\\Data\\Dokumenty\\Priv\\PfB\\Figlinq\\Python\\DeplotTraining"
    data_dir = project_dir + "\\benetech-making-graphs-accessible"
    train_image_path = data_dir + "\\train\\images"
    train_annotation_path = data_dir + "\\train\\annotations"
    test_image_path = data_dir + "\\test\\images"
    # NOTE(review): presumably a local copy of the "google/deplot"
    # checkpoint -- confirm.
    read_model_path = project_dir + "\\SavedModelsP2S.Source\\"
    save_model_path = project_dir + "\\SavedModelsP2S.1\\"

    processor = Pix2StructProcessor.from_pretrained(read_model_path)
    model = Pix2StructForConditionalGeneration.from_pretrained(read_model_path)
    tokenizer = AutoTokenizer.from_pretrained(read_model_path)

    train_dataset, test_dataset, data_collator = load_dataset(
        train_image_path, train_annotation_path, test_image_path, tokenizer
    )
    train(model, train_dataset, test_dataset, data_collator)
    save(model, processor, save_model_path)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()