What changes should be made for using Trainer with the Vision Transformer, are the keys expected by the trainer from dataset input_ids
, attention_mask
, and labels
?
class OCRDataset(torch.utils.data.Dataset):
def __init__(self, texts, tokenizer, transforms = None):
self.texts = texts
self.tokenizer = tokenizer
self.transforms = transforms
def __getitem__(self, idx):
data = generate_sample(self.texts[idx])
if data:
img, label = data
img = torch.from_numpy(img)
tokens = tokenizer(label, padding='max_length')
if self.transforms:
img = self.transforms(img)
batch = {}
batch['labels'] = tokens
batch['input_ids'] = img
return batch
transform= transforms.Compose([transforms.Normalize((0.5,), (0.5,))])
train_dataset = OCRDataset(jp_list, tokenizer, transform)
.....
.....
trainer.train()
This code throws the following error
ValueError: could not determine the shape of object type ‘BatchEncoding’