Dear team,
I am trying to fine-tune BLIP and have run into an error that I'm not sure how to solve. Could you give me some advice? Thanks!
import torch
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, questions, answers, image_paths, processor):
        self.questions = questions
        self.answers = answers
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        # get image + text
        question = self.questions[idx]
        answer = self.answers[idx]
        image = Image.open(self.image_paths[idx]).convert("RGB")
        text = question

        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer.encode(
            answer, max_length=512, pad_to_max_length=True, return_tensors="pt"
        )
        encoding["labels"] = labels
        # remove batch dimension
        # for k, v in encoding.items(): encoding[k] = v.squeeze()
        return encoding
from torch.utils.data import DataLoader
from tqdm import tqdm


def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]
    return batch
questions = [...]    # list of questions
answers = [...]      # list of corresponding answers
image_paths = [...]  # list of paths to the corresponding images
train_dataset = VQADataset(questions=questions,
                           answers=answers,
                           image_paths=image_paths,
                           processor=processor)
test_dataset = VQADataset(questions=questions,
                          answers=answers,
                          image_paths=image_paths,
                          processor=processor)
batch_size = 1
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
batch = next(iter(train_dataloader))
print(batch[0].keys()) # dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels'])
import copy
test_input = copy.copy(batch[0]).to(device)
outputs = model(**test_input)
Example of the input:
questions = ["How many cats are there?"]
answers = ["two"]
image_paths = ["./img_125.png"]
Output of model(**test_input):
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-27-f4758beea430> in <module>
      2
      3 test_input = copy.copy(batch[0]).to(device)
----> 4 outputs = model(**test_input)

6 frames
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027
   3028

ValueError: Expected input batch_size (0) to match target batch_size (511).