Integrating Accelerate into the training code

Hello, I am trying to use Accelerate to train my vision transformer model (DePlot, via Pix2Struct) on 4 GPUs on a single machine. This is my main function:

import os

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from accelerate import Accelerator
from transformers import (
    Adafactor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    get_cosine_schedule_with_warmup,
)

EPOCHS = 5  # placeholder value for the number of training epochs

def main():
    processor = Pix2StructProcessor.from_pretrained('./deplot_models/deplot_base_model/')
    model = Pix2StructForConditionalGeneration.from_pretrained('./deplot_models/deplot_base_model/')
    
    with open('data/full_vocab.txt', 'r') as f:
        full_v = [v.strip('\n') for v in f.readlines()]

    # Everything past the base vocabulary (index 50345) is a newly added token
    new_t = full_v[50345:]

    processor.tokenizer.add_tokens(new_t)
    model.resize_token_embeddings(len(processor.tokenizer))
    
    def collator(batch):
        new_batch = {"flattened_patches":[], "attention_mask":[]}
        texts = [item["text"] for item in batch]

        # Tokenize the target texts; truncation added so targets longer than
        # 512 tokens don't raise an error
        text_inputs = processor(text=texts, padding='max_length', truncation=True,
                                return_tensors="pt", add_special_tokens=True, max_length=512)

        new_batch["labels"] = text_inputs.input_ids
        # Set padding positions to -100 so the loss ignores them
        new_batch["labels"][new_batch["labels"] == processor.tokenizer.pad_token_id] = -100
        for item in batch:
            new_batch["flattened_patches"].append(item["flattened_patches"])
            new_batch["attention_mask"].append(item["attention_mask"])

        new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
        new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

        return new_batch
    
    train_dataset = DeplotDataset('./test/images', './test/targets/', processor)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1, collate_fn=collator)
    
    optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100000, num_training_steps=2_000_000*EPOCHS)
    
    accelerator = Accelerator()
    device = accelerator.device
    
    # prepare() wraps the model for DDP, shards the dataloader across the 4 GPUs,
    # and moves everything onto the right device, so no manual .to(device) or
    # DataParallel call is needed afterwards. The prepared dataloader must be
    # assigned back to the name that is iterated below.
    model, optimizer, train_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, scheduler
    )

    model.train()

    for epoch in range(EPOCHS):
        print("Epoch:", epoch)
        for idx, batch in enumerate(train_dataloader):
#             if idx % 50 == 0:
#                 print(torch.cuda.max_memory_allocated())
            optimizer.zero_grad()
            # Batches from the prepared dataloader are already on the right device
            labels = batch.pop("labels")
            flattened_patches = batch.pop("flattened_patches")
            attention_mask = batch.pop("attention_mask")

            outputs = model(flattened_patches=flattened_patches,
                            attention_mask=attention_mask,
                            labels=labels)

            loss = outputs.loss

            # Only log from one process per node to avoid duplicated output
            if idx % 50 == 0 and accelerator.is_local_main_process:
                print(f"Idx: {idx}, Loss: {loss.item()}")

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            
            del labels, flattened_patches, attention_mask, outputs, loss

            if idx % 100000 == 0:
                # Sync all processes before checkpointing; accelerator.save
                # only writes from the main process
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    print(f'Saving model {idx}')
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save({
                    'epoch': epoch,
                    'model': unwrapped_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, f'../../storage/deplot_checkpoints/extend_full_200/extend_full_ep_{epoch}_step_{idx}.pt')
                if accelerator.is_main_process:
                    print('Saved model!')


    #         if (epoch + 1) % 2 == 0:
    #             model.eval()

    #             predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask)        
    #             print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))

    #             model.train()

    # Final checkpoint after training completes
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    accelerator.save({
        'epoch': EPOCHS,
        'model': unwrapped_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }, '../../storage/deplot_checkpoints/extend_full_200/extend_full_full.pt')
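
The bottom of the script just calls main() behind the usual guard (needed so that accelerate launch actually runs the training loop in each spawned process):

if __name__ == '__main__':
    main()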

That is my main function, and this is my dataset class:

MAX_PATCHES = 1024

class DeplotDataset(Dataset):
    def __init__(self, image_folder, text_folder, processor, transform=None):
        self.image_folder = image_folder
        self.text_folder = text_folder
        self.processor = processor
        self.transform = transform

        # Sorted listings so that image i lines up with target i
        self.image_filenames = sorted(os.listdir(image_folder))
        self.text_filenames = sorted(os.listdir(text_folder))

    def __len__(self):
        return len(self.image_filenames)    

    def __getitem__(self, index):
        image_filename = self.image_filenames[index]
        text_filename = self.text_filenames[index]

        image_path = os.path.join(self.image_folder, image_filename)
        text_path = os.path.join(self.text_folder, text_filename)

        image = Image.open(image_path)
        with open(text_path, 'r') as f:
            text = f.read()

        if self.transform:
            image = self.transform(image)

        encoding = self.processor(images=image,
                                  text="Generate underlying data table of the figure below:",
                                  return_tensors="pt", add_special_tokens=True,
                                  max_patches=MAX_PATCHES)

        # Drop the extra batch dimension the processor adds
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = text
        return encoding

I used accelerate launch --multi_gpu deplot_train.py to run my code. The loss seemed to converge (from about 10 down to 2), but when I ran inference, the model spat out a bunch of gibberish. Can anyone guide me on how to change my code, or on how to use Accelerate properly?
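
For reference, my inference roughly follows the sketch below (the test-image filename is a placeholder; note that the vocabulary extension from training is re-applied before load_state_dict so the embedding sizes match the checkpoint):

import torch
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

processor = Pix2StructProcessor.from_pretrained('./deplot_models/deplot_base_model/')
model = Pix2StructForConditionalGeneration.from_pretrained('./deplot_models/deplot_base_model/')

# Repeat the exact vocabulary extension from training *before* loading the
# fine-tuned weights, otherwise the embedding shapes will not match.
with open('data/full_vocab.txt', 'r') as f:
    full_v = [v.strip('\n') for v in f.readlines()]
processor.tokenizer.add_tokens(full_v[50345:])
model.resize_token_embeddings(len(processor.tokenizer))

checkpoint = torch.load('../../storage/deplot_checkpoints/extend_full_200/extend_full_full.pt',
                        map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

image = Image.open('./test/images/example.png')  # placeholder test image
inputs = processor(images=image,
                   text="Generate underlying data table of the figure below:",
                   return_tensors="pt", add_special_tokens=True, max_patches=1024)

with torch.no_grad():
    predictions = model.generate(flattened_patches=inputs["flattened_patches"],
                                 attention_mask=inputs["attention_mask"],
                                 max_new_tokens=512)
print(processor.batch_decode(predictions, skip_special_tokens=True))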