I’m trying to fine-tune a Stable Diffusion model using the example fine-tuning code in huggingface diffusers (https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py). The main training loop uses accelerate gradient accumulation and is similar to this pseudocode example:
global_step = 0
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(unet):
        # steps for generating model inputs (noisy latents, timesteps, etc.)
        ...
        # model forward
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        # Gather the losses across all processes for logging (if we use distributed training).
        avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
        train_loss += avg_loss.item() / args.gradient_accumulation_steps
        # Backpropagate
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    if accelerator.sync_gradients:
        global_step += 1
        # (train_loss is logged here in the real script, then reset)
        train_loss = 0.0
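As a side note on the loop above: my understanding is that accelerator.sync_gradients is only True on the last micro-batch of each accumulation window, so the clip/step/zero block runs once per effective batch. A quick standalone check of that behavior with the accelerate version I’m using (toy model, not my actual training code):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = accelerator.prepare(torch.nn.Linear(8, 8))

for step in range(8):
    with accelerator.accumulate(model):
        # expected: True on the 4th and 8th iterations (steps 3 and 7)
        print(step, accelerator.sync_gradients)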
I noticed that the gradient norm (stored in the grad_norm variable) sometimes explodes to a very large value for certain samples in my dataset, and for some reason the loss then diverges even though gradient clipping is applied. I would therefore like to define a threshold: if the grad norm exceeds it, I want to discard the entire batch (i.e. the effective batch whose size equals the total batch size after gradient accumulation) and skip the gradient update altogether. I wonder if I can do something like this:
if accelerator.sync_gradients:
    grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
    if grad_norm.item() > grad_norm_threshold:
        # skip the optimizer update but keep the schedule moving,
        # and clear the accumulated gradients
        lr_scheduler.step()
        optimizer.zero_grad()
        continue
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
I’m not sure whether this correctly skips the update for the current effective batch. If not, what should I do here?
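In case it helps, here is a self-contained toy version of the loop I’m experimenting with, so the skip logic can be run in isolation. The tiny linear model, the random data, and the grad_norm_threshold value are placeholders standing in for the UNet, my dataset, and a threshold I would tune myself:

import torch
import torch.nn.functional as F
from accelerate import Accelerator

grad_norm_threshold = 10.0  # placeholder; I would tune this on my data
max_grad_norm = 1.0         # mirrors args.max_grad_norm

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(16, 16)  # stands in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 16))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

global_step = 0
for step, (inputs, target) in enumerate(train_dataloader):
    with accelerator.accumulate(model):
        model_pred = model(inputs)  # stands in for the UNet forward pass
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            # clip_grad_norm_ returns the total norm computed *before* clipping
            grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            if grad_norm.item() > grad_norm_threshold:
                # attempt to discard the whole accumulated batch
                lr_scheduler.step()
                optimizer.zero_grad()
                continue
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    if accelerator.sync_gradients:
        global_step += 1

Thanks in advance.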