It seems that setting gradient_accumulation_steps = 2 has no effect at all.
My code:
for epoch in range(init_epoch, args.num_epoch + 1):
    for iteration, (x, y) in enumerate(data_loader):
        x_0 = x.to(device, dtype=dtype, non_blocking=True)
        y = None if not use_label else y.to(device, non_blocking=True)
        model.zero_grad()
        if is_latent_data:
            z_0 = x_0 * args.scale_factor
        else:
            z_0 = first_stage_model.encode(x_0).latent_dist.sample().mul_(args.scale_factor)
        # sample t
        t = torch.rand((z_0.size(0),), dtype=dtype, device=device)
        t = t.view(-1, 1, 1, 1)
        z_1 = torch.randn_like(z_0)
        # 1 is real noise, 0 is real data
        z_t = (1 - t) * z_0 + (1e-5 + (1 - 1e-5) * t) * z_1
        u = (1 - 1e-5) * z_1 - z_0
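        # z_t linearly interpolates between data z_0 (t = 0) and noise z_1 (t = 1),
        # so u = d z_t / d t = (1 - 1e-5) * z_1 - z_0 is the target velocity for the model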
        # estimate velocity
        v = model(t.squeeze(), z_t, y)
        # flow-matching objective: squared error between predicted velocity v and target u
        loss = (v - u) ** 2
        loss = loss.mean()
        accelerator.backward(loss)
        # change below
        if (iteration + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            # clear gradient
            model.zero_grad()
            global_step += 1
            log_steps += 1
        # change above from here
        if iteration % 100 == 0:
            if accelerator.is_main_process:
                # Measure training speed:
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                accelerator.print(
                    "epoch {} iteration {}, Loss: {}, Train Steps/Sec: {:.2f}".format(
                        epoch, iteration, loss.item(), steps_per_sec
                    )
                )
                # Reset monitoring variables:
                log_steps = 0
                start_time = time()
code ref: LFM/train_flow_latent.py at main · VinAIResearch/LFM · GitHub
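For reference, the generic gradient-accumulation pattern I was trying to reproduce looks roughly like this (just a minimal sketch; names such as accum_steps and compute_loss are placeholders, not the actual LFM code):

accum_steps = 2  # placeholder for args.gradient_accumulation_steps
optimizer.zero_grad()
for iteration, (x, y) in enumerate(data_loader):
    loss = compute_loss(model, x, y)   # placeholder for the flow-matching loss above
    (loss / accum_steps).backward()    # scale so the accumulated gradient averages over the micro-batches
    if (iteration + 1) % accum_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()          # gradients are cleared only after the accumulated optimizer step

I'm not sure whether the differences between this pattern and my loop above (the per-iteration model.zero_grad() and the unscaled loss) matter here.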