Skip optimizer update when gradient norm is large with Accelerate gradient accumulation

I’m trying to fine-tune a Stable Diffusion model using the example fine-tuning script in Hugging Face diffusers (https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py). The main training loop uses Accelerate gradient accumulation and is similar to this pseudocode:

global_step = 0
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(unet):
        # steps for generating model inputs (noisy latents, timesteps, etc.)
        ...
        # model forward
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    
        # Gather the losses across all processes for logging (if we use distributed training).
        avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
        train_loss += avg_loss.item() / args.gradient_accumulation_steps
    
        # Backpropagate
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        global_step += 1
        train_loss = 0.0
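A side note on what `grad_norm` holds: assuming `accelerator.clip_grad_norm_` behaves like `torch.nn.utils.clip_grad_norm_` in the plain DDP case (which, as far as I know, it wraps there), the returned value is the total gradient norm measured before clipping, not after. A minimal PyTorch check with a hand-picked gradient:

```python
import torch

# Toy parameter with a known gradient so the norms are easy to verify by hand.
p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])  # L2 norm = 5.0

# clip_grad_norm_ rescales the gradients in place and returns the norm
# measured *before* clipping.
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(total_norm.item())     # 5.0 (pre-clipping norm)
print(p.grad.norm().item())  # approximately 1.0 (norm after clipping)
```

So comparing `grad_norm` against a threshold compares the raw, unclipped norm of the accumulated gradient.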

I noticed that the gradient norm (stored in the grad_norm variable) sometimes explodes to a very large value for certain samples in my dataset. For some reason this causes the loss to diverge even though gradient clipping is applied. I would therefore like to define a threshold: if the grad norm exceeds it, I want to discard the entire batch (whose size is the total batch size after gradient accumulation) and skip the gradient update altogether. I wonder if I can do something like this:

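# inside `with accelerator.accumulate(unet):`, right after accelerator.backward(loss)
if accelerator.sync_gradients:
    grad_norm = accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
    if grad_norm.item() > grad_norm_threshold:
        # discard the accumulated gradients and move on to the next batch
        lr_scheduler.step()
        optimizer.zero_grad()
        continue
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()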

I’m not sure whether this correctly skips the update for the current batch. If not, what should I do instead? Thanks in advance.
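For reference, here is a minimal sketch of the skip logic in plain PyTorch, with manual gradient accumulation standing in for `accelerator.accumulate`. The function `train_with_skip` and its parameters (`accum_steps`, `skip_threshold`) are hypothetical names for illustration, and a toy `nn.Linear` model replaces the UNet. The idea: call `optimizer.step()` only when the norm of the fully accumulated gradient is finite and below the threshold, and always call `optimizer.zero_grad()` at the sync point so the discarded gradients do not leak into the next accumulation window.

```python
import torch
import torch.nn.functional as F


def train_with_skip(model, optimizer, batches, accum_steps, max_grad_norm, skip_threshold):
    """Hypothetical helper: gradient accumulation that discards a whole
    accumulated batch when its pre-clipping gradient norm exceeds
    skip_threshold. Returns (num_updates, num_skipped)."""
    num_updates, num_skipped = 0, 0
    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches):
        loss = F.mse_loss(model(x), y)
        # Scale so the accumulated gradient is the mean over micro-batches.
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps != 0:
            continue  # still accumulating, no sync point yet
        # Sync point: norm of the full accumulated gradient (returned pre-clipping).
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if torch.isfinite(grad_norm) and grad_norm.item() <= skip_threshold:
            optimizer.step()
            num_updates += 1
        else:
            num_skipped += 1  # drop this window without updating the weights
        # Always clear gradients so the next window starts fresh.
        optimizer.zero_grad()
    return num_updates, num_skipped


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]
# A negative threshold forces every window (8 batches / 4 = 2 windows) to be skipped.
print(train_with_skip(model, optimizer, batches, 4, 1.0, skip_threshold=-1.0))  # (0, 2)
```

In the Accelerate loop, the same check would sit inside `if accelerator.sync_gradients:` and wrap `optimizer.step()` rather than using `continue`. Note that `continue` inside the `with` block jumps straight to the next dataloader iteration, so the `global_step += 1` and `train_loss = 0.0` bookkeeping at the end of the loop body is also skipped for that window. Whether `lr_scheduler.step()` should advance on a skipped window is a design choice; the sketch leaves the scheduler out. Under DDP the gradients are all-reduced before the check, so every process should compute the same norm and make the same skip decision (an assumption worth verifying for DeepSpeed or FSDP).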