I’m trying to fine-tune a Stable Diffusion model using the example fine-tuning script in Hugging Face diffusers (https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py). The main training loop uses accelerate's gradient accumulation and is similar to this pseudocode:
    global_step = 0
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # steps for generating model inputs (noisy latents, timesteps, etc.)
            ...
            # model forward
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            global_step += 1
            train_loss = 0.0
I noticed that the gradient norm (stored in the grad_norm variable) sometimes explodes to a very large value for certain samples in my dataset. For some reason this causes the loss to diverge even with gradient clipping applied. I would therefore like to define a threshold: if the gradient norm exceeds it, I want to discard the entire batch (that is, the full effective batch after gradient accumulation) and skip the gradient update. I wonder if I can do something like this:
    if accelerator.sync_gradients:
        grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
        if grad_norm.item() > grad_norm_threshold:
            lr_scheduler.step()
            optimizer.zero_grad()
            continue
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
I’m not sure if this can correctly skip the update of the current batch. If not, what should I do here? Thanks in advance.
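To make the intended behavior concrete, here is a minimal framework-free sketch of the control flow I am after. It uses neither accelerate nor PyTorch; the toy one-parameter model, the `train` and `grad_norm` helpers, the learning rate, and the example data are all assumptions made up for illustration. The point is only the logic: accumulate gradients over a window, measure the global norm once the window is complete, and either discard the whole window or clip and apply it.

```python
import math


def grad_norm(grads):
    """Global L2 norm of a list of per-parameter gradient values."""
    return math.sqrt(sum(g * g for g in grads))


def train(batches, accumulation_steps, max_grad_norm, grad_norm_threshold, lr=0.1):
    """Fit y = w * x with SGD, skipping any accumulated batch whose norm is too large."""
    w = 0.0
    accumulated = [0.0]  # one gradient slot per parameter (only w here)
    applied, skipped = 0, 0
    for step, (x, y) in enumerate(batches, start=1):
        # Gradient of (w * x - y) ** 2 with respect to w, averaged over the window.
        accumulated[0] += 2.0 * (w * x - y) * x / accumulation_steps
        if step % accumulation_steps != 0:
            continue  # still accumulating (analogous to sync_gradients being False)
        norm = grad_norm(accumulated)
        if norm > grad_norm_threshold:
            skipped += 1  # discard the whole accumulated batch: no parameter update
        else:
            scale = min(1.0, max_grad_norm / (norm + 1e-6))  # ordinary norm clipping
            w -= lr * scale * accumulated[0]
            applied += 1
        accumulated = [0.0]  # always reset (analogous to optimizer.zero_grad())
    return w, applied, skipped


# Hypothetical data: the third sample is an outlier that blows up the gradient.
batches = [(1.0, 2.0), (1.0, 2.0), (100.0, -500.0), (1.0, 2.0)]
w, applied, skipped = train(
    batches, accumulation_steps=2, max_grad_norm=1.0, grad_norm_threshold=10.0
)
print(w, applied, skipped)
```

In this sketch the second accumulation window contains the outlier, so its gradients are thrown away and `w` keeps the value it had after the first window. What I cannot tell is whether the `continue` in my snippet above achieves the same thing inside `accelerator.accumulate`.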