ValueError: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got

I am trying to fine-tune the DePlot model with the Hugging Face Transformers library, but I am getting a ValueError.

I am following the Pix2Struct image-captioning notebook, as suggested in the DePlot code:
https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb

I have successfully trained pix2struct-base models on the same data format (image and text) without any issues. However, I run into this problem when training the DePlot model.

My dataset format:

```
Dataset({
    features: ['image', 'text'],
    num_rows: 14289
})
```

The code below prints `True`, meaning the images are already PIL images:

```python
from PIL import Image

print(isinstance(dataset[0]["image"], Image.Image))
```
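That check only looks at row 0. A minimal sketch that scans every row, assuming each row behaves like a dict (iterating a `datasets.Dataset` yields one dict per row); `find_non_image_rows` is a hypothetical helper written for illustration, not part of any library:

```python
from typing import Any, Iterable, List, Mapping


def find_non_image_rows(
    rows: Iterable[Mapping[str, Any]], image_type: type, key: str = "image"
) -> List[int]:
    """Return the indices of rows whose `key` field is not an instance of `image_type`."""
    bad = []
    for idx, row in enumerate(rows):
        # A missing key counts as "not an image" too, since row.get() returns None.
        if not isinstance(row.get(key), image_type):
            bad.append(idx)
    return bad
```

Called as `find_non_image_rows(dataset, Image.Image)`, an empty list means every row holds a PIL image. It decodes every image, so expect it to take a while on 14,289 rows.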

Code I used:

```python
from torch.utils.data import Dataset, DataLoader

MAX_PATCHES = 1024

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Image + DePlot prompt go through the processor together.
        encoding = self.processor(images=item["image"], text="Generate underlying data table of the figure below:", return_tensors="pt", add_special_tokens=True, max_patches=MAX_PATCHES)
        # Drop the batch dimension added by return_tensors="pt".
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding
```

```python
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
```

```python
import torch

def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]

    # Text-only call: the target tables are tokenized here, no images are passed.
    text_inputs = processor(text=texts, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=True, max_length=512)

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)
```

```python
import os

import torch
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

EPOCHS = 5000

optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=40000)

device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

model.train()

for epoch in range(EPOCHS):
    print("Epoch:", epoch)

    for idx, batch in enumerate(train_dataloader):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(flattened_patches=flattened_patches,
                        attention_mask=attention_mask,
                        labels=labels)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # update the learning rate scheduler

        if (epoch + 1) % 20 == 0:
            model.eval()
            predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask)
            print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))
            model.train()
```

Error I am facing:


```
ValueError                                Traceback (most recent call last)
Cell In[20], line 18
     15 for epoch in range(EPOCHS):
     16     print("Epoch:", epoch)
---> 18     for idx, batch in enumerate(train_dataloader):
     19         labels = batch.pop("labels").to(device)
     20         flattened_patches = batch.pop("flattened_patches").to(device)

File ~/anaconda3/envs/deplot_3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:652, in _BaseDataLoaderIter.__next__(self)
    649 if self._sampler_iter is None:
    650     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    651     self._reset()  # type: ignore[call-arg]
--> 652 data = self._next_data()
    653 self._num_yielded += 1
    654 if self._dataset_kind == _DatasetKind.Iterable and \
    655         self._IterableDataset_len_called is not None and \
    656         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/deplot_3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:692, in _SingleProcessDataLoaderIter._next_data(self)
    690 def _next_data(self):
    691     index = self._next_index()  # may raise StopIteration
--> 692     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    693     if self._pin_memory:
    694         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

...

    133     "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
    134     f"jax.ndarray, but got {type(images)}."
    135 )

ValueError: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got .
```

Same question here! My guess is that because the DePlot processor combines a tokenizer and the Pix2Struct image processor, it requires the `images=` argument, as used in the `__getitem__` method of the Dataset class, but I have no idea what the images should be in the collator function.

This pointed me towards a solution.

When running the notebook code, images and text are processed separately: images in `ImageCaptioningDataset.__getitem__` and text in `collator`. So the processor gets no image input when it processes the text. However, `processor.image_processor.is_vqa` is `True` for DePlot, so the processor expects an image.

Inside the processor, a text-only input should be handed to the tokenizer alone, but that branch is skipped because `is_vqa` is `True`. Instead, `self.image_processor(images, ...)` gets called with `images=None`, which raises the error above.
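To make the branching concrete, here is a toy re-creation of the control flow just described. `ToyImageProcessor` and `toy_processor_call` are hypothetical stand-ins written for illustration, not the actual Transformers code; the real `Pix2StructProcessor.__call__` takes many more arguments and does real tokenization and patch extraction.

```python
from typing import Any, Optional


class ToyImageProcessor:
    """Hypothetical stand-in for the image processor; only the image check matters here."""

    def __init__(self, is_vqa: bool):
        self.is_vqa = is_vqa

    def __call__(self, images: Any, header_text: Optional[Any] = None) -> dict:
        # The real check accepts PIL / numpy / torch / tf / jax images; None always fails.
        if images is None:
            raise ValueError(f"Invalid image type. Expected an image, but got {type(images)}.")
        return {"flattened_patches": images, "header_text": header_text}


def toy_processor_call(image_processor: ToyImageProcessor, images: Any = None, text: Any = None) -> dict:
    """Simplified model of the processor's dispatch (not the library's code)."""
    if images is None and text is None:
        raise ValueError("You have to specify either images or text.")
    # Text-only shortcut: taken only when the image processor is NOT in VQA mode.
    if images is None and not image_processor.is_vqa:
        return {"input_ids": text}
    if image_processor.is_vqa:
        # VQA mode: the text is treated as a header to render onto the image,
        # so a text-only call ends up here with images=None and fails.
        return image_processor(images, header_text=text)
    # Non-VQA mode with an image: the text becomes decoder inputs instead.
    encoding = image_processor(images)
    if text is not None:
        encoding["decoder_input_ids"] = text
    return encoding
```

With `ToyImageProcessor(is_vqa=True)`, `toy_processor_call(..., text=["some table"])` raises the "Invalid image type ... NoneType" error, while the same call with `is_vqa=False` returns a text-only encoding.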

As a quick fix, set `processor.image_processor.is_vqa = False` before iterating over the dataloader.
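One caveat, if I'm reading the `Pix2StructProcessor` source correctly: flipping `is_vqa` globally also changes what `__getitem__` produces. With `is_vqa = False`, the processor no longer passes the prompt to the image processor as `header_text`, so "Generate underlying data table of the figure below:" is not rendered onto the chart; it comes back as `decoder_input_ids`, which the collator ignores. If you want to keep the header, one option is to call `processor.tokenizer(...)` directly in the collator. Another is to turn VQA mode off only around the text-only call. A minimal sketch of the latter, where `vqa_disabled` is a hypothetical helper (worth verifying against your Transformers version):

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def vqa_disabled(image_processor: Any) -> Iterator[Any]:
    """Temporarily set `image_processor.is_vqa = False`, restoring the old value on exit."""
    previous = image_processor.is_vqa
    image_processor.is_vqa = False
    try:
        yield image_processor
    finally:
        # Restore even if the body raised, so __getitem__ keeps rendering the prompt header.
        image_processor.is_vqa = previous


# Usage inside collator (assumes the `processor` and `texts` from the code above):
# with vqa_disabled(processor.image_processor):
#     text_inputs = processor(text=texts, padding="max_length", truncation=True,
#                             return_tensors="pt", add_special_tokens=True, max_length=512)
```

The flag is only `False` while the labels are tokenized, so the image side of the pipeline stays in VQA mode.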

Thank you, it worked! :+1: