Does anybody have any guidance on how/where to add further conditioning info to the HF stable diffusion training/inference pipelines? Everything I’ve read about stable diffusion suggests that multiple types of conditioning should be possible, but I’m not sure how to integrate them. Since the text embeddings are integrated using cross-attention, I feel like the new conditioning should probably be added there, but how? Would I concatenate it to the text embeddings, for example?
I’ve trained a VQ-VAE to generate my conditioning embeddings, but I’m wondering whether I can/should pass the (integer) latent codes straight in as my “custom class labels”, or whether I should/must normalize them first. If I normalize them, should the range be (0, 1), (-1, 1), or something else?
Any help appreciated.
Oh! Also, this tensor contains duplicates. Should I remove them? (My concern is that doing so would change the shape.)
Hi @pcuenq, I’ve just come back to work on this today, and I think your links above have moved: the code was probably updated, so they no longer point to the right lines. Just an FYI, since the answer might be a bit confusing for future readers (I went through it the other day, so it’s not a big deal for me right now). Not sure if there’s a way to avoid this in future?
You are right, I should have used a tag instead of main. Sorry about that.
Since we last talked we’ve added optional class conditioning to UNet2DModel, in addition to what was already available in UNet2DConditionModel. The difference is that UNet2DModel is simpler because it doesn’t use text conditioning (the part used for text-to-image generation). So if you don’t need to train your model for text-to-image tasks, you can use UNet2DModel instead and training should be faster. This is the revision where that feature was added, and it’s from the PR, so it should outlive future changes in main :). You’d use it the same way we discussed:
You select a class-conditioning embedding type when you create the UNet.
You pass your custom class labels in the forward pass.
This is great, thanks. I will be using both text and this new conditioning info (which I’ll pass via the class-conditioning mechanism), so I’ll stick with UNet2DConditionModel… But it’s cool that UNet2DModel has the option for class-conditioning now, so thanks for the heads-up!