Training from scratch

Hi All,

I’m wondering if anybody has any experiences to share on training from scratch? I’m finding I can get fairly decent results, in a reasonable amount of time, if my dataset is small, but as I increase the size of the dataset things get worse… or maybe they just require much more training (enough that I start to wonder if I’m getting anywhere). This is a bit counter-intuitive to me, as I’d expect the model to be able to extract more information from the larger dataset. On the other hand, the smaller dataset is obviously able to run more epochs in the same number of steps, so I guess it benefits more from seeing the same data again?

Also, in the original latent diffusion paper they talk about “converging”, but I’ve no idea what that really means in this case, as the loss doesn’t appear to give any meaningful feedback during training. In my experience, it drops to about 0.12 quickly, then just bounces around at that level and doesn’t improve. (I will note that on smaller datasets it does drop further, to about 0.07 in my case, but that still happens quite early in training.) My understanding, mostly from poking around various forums and Discord servers, is that this is pretty common. But if it is, is there any way of gauging progress during training?

Finally, I’ve been using the train_text_to_image.py script, mostly for convenience, just loading a “fresh” model instead of the pretrained model… But perhaps there are reasons why this isn’t a good idea?
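For clarity, this is roughly what I mean by a “fresh” model. A minimal sketch rather than my exact setup, assuming the diffusers config API and reusing the SD 1.x UNet architecture (the VAE and text encoder are still the pretrained ones):

```python
from diffusers import UNet2DConditionModel

# Load only the architecture/config of the pretrained UNet...
config = UNet2DConditionModel.load_config(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
# ...then build a randomly initialized model from it (no pretrained weights).
unet = UNet2DConditionModel.from_config(config)
```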

Any thoughts or experiences appreciated.

The loss surface when training diffusion models is quite uninformative IMO. A good analysis of this is available here: [2303.09556] Efficient Diffusion Training via Min-SNR Weighting Strategy
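If it helps, the gist of that paper is a per-timestep reweighting of the loss. A rough sketch for the epsilon-prediction case, not the exact implementation (gamma is a hyperparameter, usually around 5):

```python
import torch
import torch.nn.functional as F

def min_snr_loss(model_pred, noise, timesteps, alphas_cumprod, gamma=5.0):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler.
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # Min-SNR-gamma weighting for epsilon prediction: min(SNR, gamma) / SNR.
    # Down-weights high-SNR (low-noise) timesteps so they don't dominate training.
    weights = torch.clamp(snr, max=gamma) / snr
    per_sample = F.mse_loss(model_pred, noise, reduction="none").mean(dim=[1, 2, 3])
    return (weights * per_sample).mean()
```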

On the other hand, the smaller dataset is obviously able to run more epochs in the same number of steps, so I guess it benefits more from seeing the same data again?

I’d think so, but have you checked whether the model overfits the data too quickly if you do that? This is something we have observed repeatedly in our experiments. Cc: @pcuenq @valhalla


Thanks for the response!

I’ll definitely have a look over that paper—thanks for posting.

I do wonder about overfitting, though I’m not exactly sure how best to identify when it happens. Do diffusion models have anything like the failure modes GANs run into, or is it more just poor generalization? (Of course, GAN failure modes like mode collapse aren’t quite the same thing as overfitting, but they are very particular kinds of training failure.)

Regarding the overfitting part, we usually run validation inference at intervals during training to see how well the model is progressing.
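Roughly something like this, a sketch rather than exact code, assuming an SD-style setup where an inference pipeline is rebuilt from the in-training UNet every N steps (the prompts, seed, and logging destination are up to you):

```python
import torch
from diffusers import StableDiffusionPipeline

VALIDATION_PROMPTS = ["a photo of a cat", "a painting of a mountain lake"]

@torch.no_grad()
def run_validation(unet, vae, text_encoder, tokenizer, step, device="cuda"):
    # Build an inference pipeline around the current training weights.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        unet=unet, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
        safety_checker=None,
    ).to(device)
    # Fixed seed so images are comparable across validation runs.
    generator = torch.Generator(device=device).manual_seed(0)
    images = pipeline(
        VALIDATION_PROMPTS, num_inference_steps=30, generator=generator
    ).images
    for i, image in enumerate(images):
        image.save(f"val_step{step}_prompt{i}.png")
```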

Regarding failure modes, I think it’s something the community needs to investigate more deeply. Maybe an extensive paper from Karras et al., like the ones they have done for GANs many times in the past.


Actually, I’m also curious whether you’ve noticed any pattern in how the effectiveness/accuracy of prompts develops during training. For example, does the accuracy of prompting evolve at a similar rate to image quality, or is it a refinement that happens at a later stage of training?

What do you mean by prompt accuracy? Something like CLIP score?

Sorry, no, I just meant subjectively whether the prompt seems to be doing what it should.

this isn’t training from scratch, but from hacking around and experimenting with the EveryDream2 stable diffusion fine-tuner, i was able to make a fairly useful and reliable loss graph by holding the noise seed fixed when running a (fixed-set/sequence) validation pass.

the intuition is that because the diffusion process relies so heavily on noise, variance in that noise between validation passes tends to overwhelm the relatively small signal of decreasing loss. to correct for this, re-seed the noise to the same seed every time you do a validation pass (i used the isolate_rng() context manager to avoid also re-seeding the training RNG; iirc it’s in pytorch lightning). you’re still at the mercy of whatever sequence of noises that particular seed gives you for validation, but you should find you have a loss curve that traces a more clearly decreasing trajectory (even if the decrease is small).
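rough sketch of the pattern i mean (not the actual EveryDream2 code, just the general idea, assuming pre-encoded latents/text embeddings and a diffusers-style noise scheduler; the isolate_rng import path may differ across lightning versions):

```python
import torch
import torch.nn.functional as F
from pytorch_lightning.utilities.seed import isolate_rng  # location may vary by version

@torch.no_grad()
def validation_loss(unet, noise_scheduler, val_batches, seed=42, device="cuda"):
    # hold the noise/timestep sampling fixed across validation passes so the
    # loss curve isn't swamped by variance in the noise itself.
    losses = []
    with isolate_rng():  # training RNG state is restored on exit
        torch.manual_seed(seed)
        for latents, text_embeddings in val_batches:  # fixed set, fixed order
            latents = latents.to(device)
            text_embeddings = text_embeddings.to(device)
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=device,
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            pred = unet(noisy_latents, timesteps,
                        encoder_hidden_states=text_embeddings).sample
            losses.append(F.mse_loss(pred, noise).item())
    return sum(losses) / len(losses)
```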

fwiw, contrary to the link @sayakpaul provided, i’ve found this loss curve informative - it pretty reliably indicates when the fine-tuning loss has reached a minimum, and it can be trusted to start trending upward in a way that reflects the model overfitting the training data.

example: https://huggingface.co/damian0815/pashahlis-val-test-1e-6-ep30

i’m surprised no other stable diffusion fine-tuners have implemented this. also a bit suspicious…


Super interesting. Thanks for posting!