The following example may be useful to you:
# Let the accelerator wrap the model, optimizer and dataloader; the returned
# objects are the ones used for the rest of the example.
ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Run the training loop inside the join context with the prepared model
# registered; even_batches=False is passed for the duration of this block.
with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
    # NOTE(review): the second element of each batch is unpacked but never
    # used -- loss_func receives only the model outputs. Confirm that is
    # intended for this loss function.
    for inputs, _targets in dataloader:
        # Forward through the prepared model (the one handed to
        # join_uneven_inputs), not the original unprepared `model`.
        outputs = ddp_model(inputs)
        loss = loss_func(outputs)
        # Route the backward pass through the accelerator rather than
        # calling loss.backward() directly.
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()