I am trying to train the DePlot model using the Hugging Face library, but I am facing a ValueError.
I followed the Pix2Struct notebook, as suggested in the DePlot code:
https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb
I have successfully trained pix2struct-base models on the same data format (image and text) without any issues. However, I run into this problem when training the DePlot model.
My Dataset format:
Dataset({
    features: ['image', 'text'],
    num_rows: 14289
})
The code below prints True, confirming that the images are PIL.Image.Image objects:

print(isinstance(dataset[0]["image"], Image.Image))
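For reference, here is a minimal sketch of how a dataset in this shape could be built with the datasets library; the file paths and table strings below are placeholders, not my actual data:

from datasets import Dataset as HFDataset, Image as HFImage

# Hypothetical example: two chart images with their linearised tables as text.
dataset = HFDataset.from_dict({
    "image": ["charts/plot_0001.png", "charts/plot_0002.png"],
    "text": ["x | y <0x0A> 1 | 2", "x | y <0x0A> 3 | 4"],
}).cast_column("image", HFImage())  # accessing "image" now decodes to PIL.Image.Image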
Code I used:
from torch.utils.data import Dataset, DataLoader

MAX_PATCHES = 1024

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"],
                                  text="Generate underlying data table of the figure below:",
                                  return_tensors="pt",
                                  add_special_tokens=True,
                                  max_patches=MAX_PATCHES)
        # Drop the batch dimension added by the processor
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]

    text_inputs = processor(text=texts,
                            padding="max_length",
                            truncation=True,
                            return_tensors="pt",
                            add_special_tokens=True,
                            max_length=512)
    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=collator)
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
import os

EPOCHS = 5000

optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=40000)

device = torch.device("cuda:2") if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

for epoch in range(EPOCHS):
    print("Epoch:", epoch)

    for idx, batch in enumerate(train_dataloader):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(flattened_patches=flattened_patches,
                        attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # update the learning rate

        if (epoch + 1) % 20 == 0:
            model.eval()
            predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask)
            print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))
            model.train()
Error I am facing:
ValueError                                Traceback (most recent call last)
Cell In[20], line 18
     15 for epoch in range(EPOCHS):
     16     print("Epoch:", epoch)
---> 18 for idx, batch in enumerate(train_dataloader):
     19     labels = batch.pop("labels").to(device)
     20     flattened_patches = batch.pop("flattened_patches").to(device)

File ~/anaconda3/envs/deplot_3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:652, in _BaseDataLoaderIter.__next__(self)
    649 if self._sampler_iter is None:
    650     # TODO(Bug in dataloader iterator found by mypy · Issue #76750 · pytorch/pytorch · GitHub)
    651     self._reset()  # type: ignore[call-arg]
--> 652 data = self._next_data()
    653 self._num_yielded += 1
    654 if self._dataset_kind == _DatasetKind.Iterable and \
    655         self._IterableDataset_len_called is not None and \
    656         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/deplot_3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:692, in _SingleProcessDataLoaderIter._next_data(self)
    690 def _next_data(self):
    691     index = self._next_index()  # may raise StopIteration
--> 692 data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    693 if self._pin_memory:
    694     data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
âŚ
    133     "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
    134     f"jax.ndarray, but got {type(images)}."
    135 )
ValueError: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got .
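A minimal check along these lines (assuming the same dataset and train_dataset objects from the code above) should show whether any raw sample fails to decode to a PIL image, and whether __getitem__ works outside the DataLoader:

from PIL import Image

# Find any sample whose "image" entry is not a decoded PIL image.
bad = [i for i in range(len(dataset)) if not isinstance(dataset[i]["image"], Image.Image)]
print("non-PIL samples:", bad)

# Index the wrapper dataset directly, exercising the processor without the collator.
sample = train_dataset[0]
print(sample.keys())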