Text-to-image training loss suddenly becomes NaN

Description

Hello,

I am trying to fine-tune the Stable Diffusion 2.1 model on a custom dataset, using the example script provided here. Training starts fine and the loss decreases, but then, at a random point, the loss becomes NaN and the model starts to output black images. It is not an issue with the dataset: once, this happened only after an entire epoch, so the script had already iterated over every sample. The timing is completely random. Once I got it to train for 1800 steps before running into the problem, while on average it takes around 200 steps.
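A minimal NumPy sketch of why a model does not recover once the loss turns NaN: a single non-finite gradient written into the weights makes every later output that depends on those weights NaN as well, and further "good" updates cannot undo it. The toy linear model, the plain SGD update, and the helper names `sgd_step` and `guarded_sgd_step` are hypothetical illustrations, not code from the example script. A guard that skips any update whose gradient is not finite keeps the weights intact, at the cost of silently dropping that batch.

```python
import numpy as np

def sgd_step(w, grad, lr):
    """Plain SGD update: w <- w - lr * grad (hypothetical toy optimizer)."""
    return w - lr * grad

def guarded_sgd_step(w, grad, lr):
    """Same update, but skipped when any gradient entry is NaN or inf.

    Returns (new_weights, applied) so the caller can log skipped steps.
    """
    if not np.all(np.isfinite(grad)):
        return w, False
    return sgd_step(w, grad, lr), True

w = np.array([0.5, -1.0, 2.0])           # toy "model" weights
x = np.array([1.0, 1.0, 1.0])            # toy input; the "output" is w @ x
bad_grad = np.array([0.1, np.nan, 0.3])  # one non-finite gradient entry
good_grad = np.array([0.1, 0.1, 0.1])    # a perfectly ordinary gradient

# Unguarded: the NaN is written into the weights ...
w_bad = sgd_step(w, bad_grad, lr=1e-3)
# ... and a later, healthy update cannot remove it again.
w_still_bad = sgd_step(w_bad, good_grad, lr=1e-3)
print(np.isnan(w_bad).any(), np.isnan(w_still_bad @ x))  # True True

# Guarded: the bad step is skipped and the weights stay finite.
w_ok, applied = guarded_sgd_step(w, bad_grad, lr=1e-3)
print(applied, np.isfinite(w_ok).all())  # False True
```

In a PyTorch loop, a comparable (hypothetical) adaptation would check `torch.isfinite(loss)` before the optimizer step; whether skipping a batch is better than stopping to inspect it is a judgment call.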

I have tried two different datasets. One was a custom dataset of 768x786 images with 2000 samples. The other was this fashion dataset from Kaggle, of which I used a subset of around 1800 images.

I have tried batch sizes up to 4, and I have tried training both with and without the xformers flag.

Any ideas what the problem could be?
Thanks.

The command I am using (the training script is unchanged):

accelerate launch train_text_to_image.py --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" --train_data_dir="dataset\dataset_fashion" --resolution=512 --center_crop --train_batch_size=1 --gradient_accumulation_steps=1 --gradient_checkpointing --max_train_steps=10000 --learning_rate=1e-6 --max_grad_norm=0.5 --lr_scheduler="constant" --output_dir="output" --checkpointing_steps=200 --enable_xformers_memory_efficient_attention

System Info

  • GPU: 24GB Titan RTX
  • diffusers version: 0.15.0.dev0
  • Platform: Windows-10-10.0.19044-SP0
  • Python version: 3.9.16
  • PyTorch version (GPU?): 1.13.1+cu116 (True)
  • Huggingface_hub version: 0.13.3
  • Transformers version: 4.27.3
  • Accelerate version: 0.18.0
  • xFormers version: 0.0.16
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

I’ve seen this a couple of times. In my case I got around it by lowering the learning rate, but yours is already quite low, so I’m not sure that will help, unfortunately.
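A toy illustration of how an overly large learning rate alone can turn values into inf and then NaN. Gradient descent on f(w) = w^2 multiplies w by (1 - 2*lr) at each step, so any lr above 1 makes |w| grow geometrically until it overflows; the following update then computes inf - inf, which is NaN. The function `descend` and the quadratic objective are assumptions chosen for clarity; they say nothing about the learning rate at which a Stable Diffusion fine-tune actually diverges.

```python
import math
import numpy as np

def descend(lr, steps, dtype=np.float32):
    """Gradient descent on f(w) = w**2 (gradient 2*w), keeping every value in `dtype`.

    Returns the value of w after each step as Python floats.
    """
    w, lr, two = dtype(1.0), dtype(lr), dtype(2.0)
    history = []
    with np.errstate(over="ignore", invalid="ignore"):  # silence overflow/NaN warnings
        for _ in range(steps):
            w = w - lr * (two * w)  # each step multiplies w by (1 - 2*lr)
            history.append(float(w))
    return history

high = descend(lr=1.5, steps=200)  # factor -2: |w| doubles every step
low = descend(lr=0.1, steps=200)   # factor ~0.8: w shrinks toward 0

first_inf = next(i for i, v in enumerate(high) if math.isinf(v))
first_nan = next(i for i, v in enumerate(high) if math.isnan(v))
print(first_inf, first_nan)  # 127 128: overflow to inf, then inf - inf = NaN
print(all(math.isfinite(v) for v in low))  # True
```

Once the first NaN appears it never goes away, which matches the "fine, then suddenly NaN forever" pattern described in the question.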

What happens if you up the grad norm to 1.0?

Same result. I was initially training with 1.0 and later changed it to 0.5.
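As far as I can tell, `--max_grad_norm` feeds a global-norm gradient clipping call in the example script, which would be consistent with 0.5 and 1.0 behaving the same: clipping rescales the gradients by `max_norm / total_norm`, and if any gradient entry is NaN the total norm is NaN too, so the rescaled gradients stay NaN. The function `clip_by_global_norm` below is a hypothetical NumPy re-implementation of that rule (including the small `eps` that PyTorch's `clip_grad_norm_` adds to the norm), not the library call itself.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm.

    Uses coef = min(1, max_norm / (total_norm + eps)); np.minimum is used so
    that a NaN norm propagates into the coefficient.
    Returns (clipped_grads, total_norm).
    """
    with np.errstate(invalid="ignore"):
        total_norm = np.sqrt(sum(np.sum(g * g) for g in grads))
        coef = np.minimum(1.0, max_norm / (total_norm + eps))
        return [g * coef for g in grads], total_norm

healthy = [np.array([3.0, 4.0])]      # global norm 5.0
clipped_h, norm_h = clip_by_global_norm(healthy, max_norm=0.5)

small = [np.array([0.1, 0.1])]        # global norm ~0.14, below the limit
clipped_s, _ = clip_by_global_norm(small, max_norm=1.0)

poisoned = [np.array([3.0, np.nan])]  # one NaN entry
clipped_p, norm_p = clip_by_global_norm(poisoned, max_norm=1.0)

print(norm_h, np.linalg.norm(clipped_h[0]))  # 5.0 and a norm just under 0.5
print(norm_p, clipped_p[0])                  # NaN norm, all-NaN gradients
```

Because clipping cannot repair a NaN, the threshold mainly matters before things go wrong, by limiting large but still finite updates.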

Did anyone find out what causes the problem?


Thanks for sharing.
Peter

Hi, did you later figure out what the problem was? Is it related to FP16?
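On the FP16 question: one common way half precision produces NaNs is overflow. FP16 cannot represent magnitudes above 65504, so a large intermediate value (here a hypothetical attention score of 300 * 300 = 90000) becomes inf, and a numerically stable softmax then computes inf - inf = NaN. The same scores in FP32 are harmless. This NumPy sketch only shows the mechanism; `softmax` is a generic helper, not the attention kernel used by xformers, and whether this run trains in FP16 at all depends on the Accelerate configuration, since the command above does not set a mixed-precision option.

```python
import numpy as np

def softmax(scores):
    """Softmax with the usual max-subtraction trick; computes in the input's dtype."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

fp16_max = float(np.finfo(np.float16).max)  # largest finite FP16 value: 65504.0

# Hypothetical attention inputs: the first score is 300 * 300 = 90000.
x = np.array([300.0, 1.0])

with np.errstate(over="ignore", invalid="ignore"):
    x16 = x.astype(np.float16)
    scores16 = x16 * x16         # 90000 does not fit in FP16, so it becomes inf
    probs16 = softmax(scores16)  # inf - inf in the max subtraction gives NaN

x32 = x.astype(np.float32)
scores32 = x32 * x32             # 90000 is well within FP32 range
probs32 = softmax(scores32)      # finite probabilities that sum to 1

print(fp16_max, np.isinf(scores16[0]), np.isnan(probs16).all(), probs32.sum())
```

Upcasting just the overflow-prone computation to FP32, or training without FP16, are the usual ways to test whether this mechanism is involved.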