How to use Trainer with Vision Transformer

What changes need to be made to use the Trainer with the Vision Transformer? Are the keys the Trainer expects from the dataset input_ids, attention_mask, and labels?

import torch
from torch.utils.data import Dataset

class OCRDataset(Dataset):
    def __init__(self, texts, tokenizer, transforms=None):
        self.texts = texts
        self.tokenizer = tokenizer
        self.transforms = transforms

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        data = generate_sample(self.texts[idx])
        if data:
            img, label = data
            img = torch.from_numpy(img)
            tokens = self.tokenizer(label, padding='max_length')
        if self.transforms:
            img = self.transforms(img)
        batch = {}
        batch['labels'] = tokens
        batch['input_ids'] = img
        return batch

transform = transforms.Compose([transforms.Normalize((0.5,), (0.5,))])
train_dataset = OCRDataset(jp_list, tokenizer, transform)


This code throws the following error:

ValueError: could not determine the shape of object type 'BatchEncoding'


I do have a demo notebook on using the Trainer for fine-tuning the Vision Transformer here: Transformers-Tutorials/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_🤗_Trainer.ipynb at master · NielsRogge/Transformers-Tutorials · GitHub.

ViT doesn’t expect input_ids and attention_mask as input, but pixel_values instead. Note that we will add support for attention_mask in the future.

Thanks, this worked for me. Can you tell me how to load the model from a checkpoint after training? In your example, the model trained is an object of the ViTForImageClassification class, so from_pretrained() cannot be used.

from_pretrained() can be used on any model class, so also on ViTForImageClassification.