Performing gradient accumulation with Accelerate

Setting gradient accumulation to 2 seems to have no effect at all.

My code:

```python
for epoch in range(init_epoch, args.num_epoch + 1):
    for iteration, (x, y) in enumerate(data_loader):
        x_0 = x.to(device, dtype=dtype, non_blocking=True)
        y = None if not use_label else y.to(device, non_blocking=True)
        model.zero_grad()
        if is_latent_data:
            z_0 = x_0 * args.scale_factor
        else:
            z_0 = first_stage_model.encode(x_0).latent_dist.sample().mul_(args.scale_factor)
        # sample t
        t = torch.rand((z_0.size(0),), dtype=dtype, device=device)
        t = t.view(-1, 1, 1, 1)
        z_1 = torch.randn_like(z_0)
        # 1 is real noise, 0 is real data
        z_t = (1 - t) * z_0 + (1e-5 + (1 - 1e-5) * t) * z_1
        u = (1 - 1e-5) * z_1 - z_0
        # estimate velocity
        v = model(t.squeeze(), z_t, y)
        # assumed: MSE between predicted velocity v and target velocity u
        # (the line defining `loss` is missing from the post)
        loss = ((v - u) ** 2).mean()
        accelerator.backward(loss)

        # --- changed: start ---
        if (iteration + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()

            # clear accumulated gradients
            model.zero_grad()
            global_step += 1
            log_steps += 1
        # --- changed: end ---

        if iteration % 100 == 0:
            if accelerator.is_main_process:
                # Measure training speed:
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                accelerator.print(
                    "epoch {} iteration {}, Loss: {}, Train Steps/Sec: {:.2f}".format(
                        epoch, iteration, loss.item(), steps_per_sec
                    )
                )
                # Reset monitoring variables:
                log_steps = 0
                start_time = time()
```

Code reference: LFM/train_flow_latent.py at main · VinAIResearch/LFM · GitHub
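A minimal sketch of why a loop shaped like the one above does not accumulate anything, using a hypothetical toy `torch.nn.Linear` model and random data in place of the post's flow model. The `model.zero_grad()` at the top of every iteration discards the gradient of the previous micro-batch, so each `optimizer.step()` only ever sees the last one. Correct accumulation lets gradients add up across micro-batches and scales each loss by `1 / accum_steps`, which reproduces the full-batch gradient.

```python
import torch

# Hypothetical toy setup: a linear model and one batch of 8 samples.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
accum_steps = 2


def full_batch_grad():
    """Weight gradient of the mean-squared error over all 8 samples at once."""
    model.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    return model.weight.grad.clone()


def accumulated_grad(zero_every_iteration):
    """Weight gradient after splitting the batch into `accum_steps` micro-batches."""
    model.zero_grad()
    for x_mb, y_mb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
        if zero_every_iteration:
            # Mirrors the model.zero_grad() at the top of the loop in the post:
            # it throws away the gradient of the previous micro-batch.
            model.zero_grad()
        loss = ((model(x_mb) - y_mb) ** 2).mean()
        # Scale so that the summed micro-batch gradients equal the full-batch mean.
        (loss / accum_steps).backward()
    return model.weight.grad.clone()


g_full = full_batch_grad()
g_accum = accumulated_grad(zero_every_iteration=False)
g_last_only = accumulated_grad(zero_every_iteration=True)

# Accumulation matches the full batch; zeroing every iteration keeps only the last micro-batch.
print(torch.allclose(g_full, g_accum, atol=1e-6))
print(torch.allclose(g_full, g_last_only, atol=1e-6))
```

In other words, with `gradient_accumulation_steps = 2` the loop as posted only halves the number of optimizer updates; the effective batch size per update stays the same.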

You also need to use Accelerate's gradient accumulation wrapper, `with accelerator.accumulate(model):`, as shown here: accelerate/examples/by_feature/gradient_accumulation.py at main · huggingface/accelerate · GitHub