How to train the GIT model on particular datasets

Hello everyone,

I wanted to try my hand at using the GIT model for Image Captioning and VQA.
I want to fine-tune it on the VizWiz dataset, using the version adapted to each task (VizWiz for Image Captioning and VizWiz for VQA).
However, I’m having trouble figuring out how to use the Trainer to fine-tune the GIT model.
I don’t understand how the dataset should be passed to the Trainer.
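Here is a minimal sketch of a map-style PyTorch `Dataset` for this. It assumes the records are dicts with `image` and `text` keys and that `processor` is a GIT processor (e.g. loaded with `AutoProcessor.from_pretrained`). The class name, the keys and the `max_length` default are hypothetical. Each item is a dict of tensors. In the Hugging Face `GitForCausalLM` implementation the labels are shifted by one position inside the model, so here the labels are just the input ids, with padding masked out as an extra design choice.

```python
import torch
from torch.utils.data import Dataset


class GitCaptioningDataset(Dataset):
    """Hypothetical wrapper turning (image, text) records into GIT model inputs.

    `records` is any indexable collection of dicts with "image" and "text" keys
    (an assumption about how the VizWiz data was loaded). `processor` is expected
    to behave like a GIT processor: called with images=..., text=..., it returns
    a dict-like object of tensors with a leading batch dimension of 1.
    """

    def __init__(self, records, processor, max_length=64):
        self.records = records
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        encoding = self.processor(
            images=record["image"],
            text=record["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the batch dimension of 1 so the DataLoader / Trainer can stack items.
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        # The model shifts labels internally, so the labels are the input ids
        # themselves; padding positions are set to -100 so the loss ignores them.
        labels = item["input_ids"].clone()
        labels[item["attention_mask"] == 0] = -100
        item["labels"] = labels
        return item
```

An instance can be wrapped in a `DataLoader` for a manual loop or passed as `train_dataset=` to `Trainer`. The Trainer calls `model(**batch)`, so the `labels` key reaches the model's forward method directly.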

I also tried using a plain PyTorch training loop, but there I don’t understand what the labels should be.
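As far as I understand the Hugging Face `GitForCausalLM` code, the model shifts the labels by one position internally. That would mean the labels should be the same sequence as `input_ids` (same length), not a separate one. Below is a minimal pure-Python sketch for a single VQA pair. It assumes the question and answer have already been tokenized without special tokens, and `build_vqa_example` is a hypothetical helper. Masking the `[CLS]` + question and padding positions with -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) is one common design choice, so that only the answer tokens and the final `[SEP]` contribute to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_vqa_example(question_ids, answer_ids, cls_id, sep_id, pad_id, max_length):
    """Hypothetical helper: build input_ids / attention_mask / labels for one pair.

    question_ids and answer_ids are token ids WITHOUT special tokens.
    The labels mirror input_ids, with the prompt ([CLS] + question) and the
    padding replaced by IGNORE_INDEX so they are excluded from the loss.
    """
    input_ids = [cls_id] + list(question_ids) + list(answer_ids) + [sep_id]
    input_ids = input_ids[:max_length]  # truncate if too long

    n_prompt = 1 + len(question_ids)  # [CLS] + question tokens
    labels = [IGNORE_INDEX] * min(n_prompt, len(input_ids)) + input_ids[n_prompt:]
    attention_mask = [1] * len(input_ids)

    # Pad everything to max_length; padded positions are ignored by the loss.
    n_pad = max_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    labels += [IGNORE_INDEX] * n_pad
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```

At inference time, the model would then receive only `[CLS]` + question tokens and be asked to generate the rest, which matches how the prompt is laid out here.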

Here is an example of my training loop for the VQA task.

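    import torch
    from os.path import join

    cls_id = processor.tokenizer.cls_token_id
    model.train()

    for epoch in range(args.num_epochs):
        for idx, batch in enumerate(train_dataloader):
            # labels and input_ids are already tokenized; this assumes the
            # DataLoader yields tensors of shape (batch_size, seq_len)
            input_ids = batch['input_ids']
            # A Python list cannot simply be prepended to a batched tensor,
            # so prepend [CLS] to every row with torch.cat
            cls_col = torch.full((input_ids.size(0), 1), cls_id, dtype=input_ids.dtype)
            input_ids = torch.cat([cls_col, input_ids], dim=1).to('cuda')

            # Here the labels are of the form "Question ? Answer"
            labels = batch['labels']
            cls_col = torch.full((labels.size(0), 1), cls_id, dtype=labels.dtype)
            labels = torch.cat([cls_col, labels], dim=1).to('cuda')

            pixel_values = batch['pixel_values'].to('cuda')
            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)

            # Backpropagate, then update the weights and the learning rate
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if (progress_bar.n + 1) % args.save_steps == 0:
                print("reaching checkpoint, saving")
                # There is no Trainer in a manual loop, so save the model directly
                model.save_pretrained(join(args.output_dir, f'checkpoint-{progress_bar.n + 1}'))
                torch.save({'optimizer_state_dict': optimizer.state_dict(),
                            'lr_state_dict': lr_scheduler.state_dict(),
                            }, join(args.output_dir, 'optimizer'))

            if (progress_bar.n + 1) == num_training_steps:
                torch.save({'optimizer_state_dict': optimizer.state_dict(),
                            'lr_state_dict': lr_scheduler.state_dict(),
                            }, join(args.output_dir, 'optimizer'))

            progress_bar.update(1)

        # Loss of the last batch of the epoch
        print(f"Epoch {epoch} : loss = {loss.item()}")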

However, this approach does not seem to work, as it produces weird output.
The model either does not answer at all or generates the same random word multiple times, so I figured something must be wrong.

Thank you!

Hi! I don’t know much about fine-tuning GIT for VQA either, but I did find an excellent demo on how to fine-tune GIT for image captioning:
Transformers-Tutorials/GIT at master · NielsRogge/Transformers-Tutorials
I hope you might find it helpful. :grinning:

Nice!

Thank you very much.
I will look into it and post an update.