How do I use accelerator.prepare() with two optimizers that have different learning rates for two separate models?

This is basically what I'm doing now:

model1, model2 = accelerator.prepare(model1, model2)
optimizer1 = optimizer_cls(
    model1.parameters(),
    lr=config.train.learning_rate_1,
    betas=(config.train.adam_beta1, config.train.adam_beta2),
    weight_decay=config.train.adam_weight_decay,
    eps=config.train.adam_epsilon,
)
optimizer2 = optimizer_cls(
    model2.parameters(),
    lr=config.train.learning_rate_2,
    betas=(config.train.adam_beta1, config.train.adam_beta2),
    weight_decay=config.train.adam_weight_decay,
    eps=config.train.adam_epsilon,
)
optimizer1, optimizer2 = accelerator.prepare(optimizer1, optimizer2)
for epoch in range(config.num_epochs):
    output1 = model1(input)
    output2 = model2(input)
    loss = loss1(output1) + loss2(output2)
    accelerator.backward(loss)
    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model1.parameters(), config.train.max_grad_norm)
        accelerator.clip_grad_norm_(model2.parameters(), config.train.max_grad_norm)
        optimizer1.step()
        optimizer2.step()
        optimizer1.zero_grad()
        optimizer2.zero_grad()
    accelerator.save_state()

Am I doing this right? Thanks!

Yes, that is indeed correct.

Is there any documentation on how to use prepare() in various cases? I have two models but want to use the same optimizer for both. How should I call prepare() in that case?

I'm training the two models in tandem, say model1 and model2. When calling prepare(), can I call it separately for each, as in the code below?

model1, dataloader, optimizer = accelerator.prepare(model1, dataloader, optimizer)
if some_condition:
    model2 = accelerator.prepare(model2)

Or do I have to call prepare() only once, as in the below code?

if not some_condition:
    model1, dataloader, optimizer = accelerator.prepare(model1, dataloader, optimizer)
else:
    model1, model2, dataloader, optimizer = accelerator.prepare(model1, model2, dataloader, optimizer)

I've asked this on Stack Overflow as well.