According to this thread, it seems like it would work fine: Multiple gradient updates with two separate losses and two classifiers sharing the same encoder - nlp - PyTorch Forums. During accumulation each loss is backpropagated separately and the gradients are just added into the same .grad buffers, which I believe wouldn't cause them to entangle.
The key here is to not call zero_grad() between the two backward calls (that would wipe out the gradients from the first loss). So I think we can essentially rewrite this as:
# Suppose gradient_accumulation_steps is set to 2 on the Accelerator.
optimizer = torch.optim.AdamW(unet.parameters())  # any optimizer works here

for batch in dataloader:  # the usual training loop
    with accelerator.accumulate(unet):
        outputs = unet(batch)
        loss1 = loss_func1(outputs)
        # retain_graph=True keeps the graph alive for the second backward.
        loss1.backward(retain_graph=True)
        loss2 = loss_func2(outputs)
        loss2.backward()
        optimizer.step()
        optimizer.zero_grad()
And this should work/be okay; the key is passing retain_graph=True on the first backward call so the graph isn't freed before the second one runs.
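To convince myself the two backward calls don't interfere, here is a minimal sketch in plain PyTorch (no Accelerate; the toy model, losses, and tensor shapes are placeholders I made up) checking that two backward passes, the first with retain_graph=True, accumulate the same gradients as a single backward on the summed loss:

import torch

torch.manual_seed(0)
toy_model = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)

# Two separate backward passes through the same graph.
out = toy_model(x)
loss1 = out.pow(2).mean()
loss2 = out.abs().mean()
loss1.backward(retain_graph=True)  # keep the graph for the second backward
loss2.backward()
grads_two_passes = [p.grad.clone() for p in toy_model.parameters()]

# Single backward on the summed loss, for comparison.
toy_model.zero_grad()
out = toy_model(x)
(out.pow(2).mean() + out.abs().mean()).backward()
grads_combined = [p.grad.clone() for p in toy_model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_two_passes, grads_combined)))
# Expected: True (the gradients just add up, they don't entangle).

One more note: since the snippet above already uses accelerator.accumulate(), I believe the more idiomatic calls would be accelerator.backward(loss1, retain_graph=True) and accelerator.backward(loss2) rather than calling .backward() on the losses directly, since accelerator.backward handles loss scaling (mixed precision / gradient accumulation) and should forward retain_graph through to the underlying backward.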