Pix2Struct-based model DDP code conversion

Hi, I’m trying to convert the code below to DeepSpeed or PyTorch Lightning so that I can train on multiple GPUs on a single machine with DDP.

device = "cuda:3" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(EPOCHS):
    with open(f"{MODEL_NAME}_full_loss.txt", "a") as f:
        f.write(f"Epoch: {epoch+1}\n")

    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels,
        )

        loss = outputs.loss

        if idx % 50 == 0:
            with open(f"{MODEL_NAME}_full_loss.txt", "a") as f:
                f.write(f"Idx: {idx}, Loss: {loss.item()}\n")
            print(f"Idx: {idx}, Loss: {loss.item()}")

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    
        if idx % 10000 == 0:
            torch.save(
                {
                    "epoch": (epoch + 1),
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                },
                f"./deplot_checkpoints/{MODEL_NAME}/{MODEL_NAME}_ep_{(epoch + 1)}_step_{idx}.pt",
            )
            
    scheduler.step()
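
For the DeepSpeed option, my understanding is that the same loop would be wrapped roughly like this. This is only an untested sketch: the ds_config values, the save_checkpoint tag names, and the launch command in the comments are placeholders I made up, not working code.

import deepspeed
import torch

# untested sketch; ds_config values are placeholders
ds_config = {
    "train_micro_batch_size_per_gpu": BATCH_SIZE,  # assumed to equal the DataLoader batch size
    "gradient_accumulation_steps": 1,
}

# deepspeed.initialize wraps the model/optimizer in a DDP-capable engine;
# the script would be launched with the deepspeed launcher (e.g. `deepspeed train.py`)
# so that one process runs per GPU
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)
device = torch.device("cuda", model_engine.local_rank)

model_engine.train()
for epoch in range(EPOCHS):
    # NOTE: with several processes the DataLoader would also need a DistributedSampler
    # so that each rank sees a different shard of the data
    for idx, batch in enumerate(train_dataloader):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model_engine(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels,
        )

        model_engine.backward(outputs.loss)  # handles gradient averaging across ranks
        model_engine.step()                  # optimizer step + zero_grad

        if idx % 10000 == 0:
            # rank-aware checkpointing instead of calling torch.save on every process
            model_engine.save_checkpoint(
                f"./deplot_checkpoints/{MODEL_NAME}", tag=f"ep_{epoch + 1}_step_{idx}"
            )

    scheduler.step()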

I’ve been trying to convert the code with PyTorch Lightning so that it can run on multiple GPUs, but to no avail.
Below is the code I’ve been using.

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

    processor = Pix2StructProcessor.from_pretrained("./deplot_models/warmup4/")
    model = Pix2StructForConditionalGeneration.from_pretrained(
        "./deplot_models/warmup4/"
    )
    def collator(batch):
        new_batch = {"flattened_patches": [], "attention_mask": []}
        texts = [item["text"] for item in batch]

        text_inputs = processor(
            text=texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=MAX_BATCH_LEN,                                                               
        )

        new_batch["labels"] = text_inputs.input_ids
        new_batch["labels"][new_batch["labels"] == processor.tokenizer.pad_token_id] = -100
        for item in batch:
            new_batch["flattened_patches"].append(item["flattened_patches"])
            new_batch["attention_mask"].append(item["attention_mask"])

        new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
        new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

        return new_batch
    train_dataset = DeplotDataset(
        "./dataset4/images/", "./dataset4/targets/", processor, split="train"
    )
    val_dataset = DeplotDataset(
        "./dataset4/images/", "./dataset4/targets/", processor, split="validation"
    )
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator, num_workers=16)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collator, num_workers=16)
    config = {
        "max_steps": len(train_dataset) // BATCH_SIZE * EPOCHS,
        "num_warmup_steps": len(train_dataset) / BATCH_SIZE * EPOCHS // 50,
        "check_val_every_n_steps": 100,
        "gradient_clip_val": 1.0,
        "accumulate_grad_batches": 64,
        "verbose": True
    }
    pl_module = Deplot(config, processor, model)
    trainer = pl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        devices=[0, 1, 3],
        max_steps=config.get("max_steps"),
#         check_val_every_n_epoch=config.get("check_val_every_n_steps"),
        log_every_n_steps=100,
        enable_progress_bar=True,
        enable_model_summary=True,
        gradient_clip_val=config.get("gradient_clip_val"),
        accumulate_grad_batches=config.get("accumulate_grad_batches")
    )
    trainer.fit(pl_module, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
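
The Deplot class used above is a LightningModule along these lines. This is a minimal sketch rather than the exact class; the AdamW learning rate and the get_cosine_schedule_with_warmup scheduler in particular are assumptions on my part.

    from transformers import get_cosine_schedule_with_warmup

    class Deplot(pl.LightningModule):
        def __init__(self, config, processor, model):
            super().__init__()
            self.config = config
            self.processor = processor
            self.model = model

        def training_step(self, batch, batch_idx):
            outputs = self.model(
                flattened_patches=batch["flattened_patches"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            # sync_dist=True averages the logged loss across ranks under DDP
            self.log("train_loss", outputs.loss, prog_bar=True, sync_dist=True)
            return outputs.loss

        def validation_step(self, batch, batch_idx):
            outputs = self.model(
                flattened_patches=batch["flattened_patches"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            self.log("val_loss", outputs.loss, prog_bar=True, sync_dist=True)

        def configure_optimizers(self):
            # lr=1e-5 is a placeholder value
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(self.config["num_warmup_steps"]),
                num_training_steps=int(self.config["max_steps"]),
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
            }

As far as I understand, with strategy="ddp" Lightning wraps the model in DistributedDataParallel and swaps a DistributedSampler into the dataloaders automatically, so the collator and dataloaders above should not need changes.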

Check the following notebook to get an idea: