Add additional conditioning info

Hi All,

Does anybody have any guidance on how/where to add further conditioning info to the HF Stable Diffusion training/inference pipelines? Everything I’ve read about Stable Diffusion suggests that multiple different types of conditioning should be possible, but I’m not sure how to integrate them. Since the text embeddings are integrated using cross-attention, I feel like the extra conditioning should probably be added there, but how? Would I concatenate it to the text embeddings, for example?
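
For example (purely hypothetical, since this is exactly what I’m unsure about), would something along these lines be reasonable? The shapes are just illustrative:

```python
import torch

# Hypothetical CLIP text embeddings: (batch, seq_len, dim)
text_embeds = torch.randn(2, 77, 768)
# Hypothetical extra conditioning, already projected to the same dim
extra_cond = torch.randn(2, 4, 768)

# Concatenate along the sequence axis so cross-attention sees both?
encoder_hidden_states = torch.cat([text_embeds, extra_cond], dim=1)  # (2, 81, 768)
```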

Any thoughts appreciated.

Hi @jbmaxwell! That’s an excellent question.

The easiest way, I think, would be to leverage UNet2DConditionModel and indicate, when you create it, that you’ll be using custom class embeddings. Similar to what you suspected, these embeddings are simply added, in this case to the timestep embeddings. If you use the "timestep" class_embed_type, for example, then you pass your custom class labels during the forward pass; they go through an embedding layer and are added to the timestep embeddings.
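
Here’s a minimal sketch of what I mean. The model sizes are placeholders so it runs quickly (they’re not Stable Diffusion’s defaults), and it assumes a diffusers version that includes class_embed_type:

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny placeholder config; class_embed_type is the relevant part.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(64, 128),
    cross_attention_dim=64,
    class_embed_type="timestep",  # custom labels get a timestep-style embedding
)

batch = 2
sample = torch.randn(batch, 4, 32, 32)        # noisy latents
timesteps = torch.randint(0, 1000, (batch,))
text_embeds = torch.randn(batch, 16, 64)      # stand-in for the text-encoder output

# Your custom conditioning goes in as class_labels; internally it is
# embedded and added to the timestep embeddings.
class_labels = torch.randint(0, 1000, (batch,))

pred = unet(sample, timesteps, encoder_hidden_states=text_embeds,
            class_labels=class_labels).sample
```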

I hope that’s enough to get you started! Please do share whether it works, as well as what you’re trying to achieve (if you can make it public).


Excellent, thanks so much @pcuenq!

Okay, I’ve got a bit further…

I’ve trained a VQ-VAE to generate my conditioning embeddings, but I’m wondering whether I can/should pass the (integer) latent codes straight in as my “custom class labels”, or whether I should/must normalize them first. If I normalize them, is it (0, 1), or (-1, 1), or…? 🙂

Any help appreciated.

Oh! Also, this tensor contains duplicates. Should I remove them? (My concern is that doing so would change the shape…)
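
For concreteness, here’s roughly the situation (the numbers are made up, not my actual model):

```python
import torch

# Hypothetical VQ-VAE output: a grid of integer codebook indices per sample.
codebook_size = 512
codes = torch.randint(0, codebook_size, (2, 8, 8))  # (batch, h, w) integer codes

flat = codes.view(2, -1)                            # (batch, 64) code sequence
print(flat[0].unique().numel(), "unique codes out of", flat.shape[1])
# Removing duplicates would give each sample a different length, hence my worry.
```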

Hi @pcuenq, I’ve just come back to this today, and I think your links above have changed/moved, i.e., the code was updated so they no longer point to the right lines. Just an FYI, since the answer might be a bit confusing for future readers (I went through it the other day, so it’s not a huge deal right away). Not sure if there’s a way to avoid this in future…?

Hi @jbmaxwell!

You are right, I should have used a tag instead of main. Sorry about that.

Since we last talked, we’ve added optional class conditioning to UNet2DModel, in addition to what was already available in UNet2DConditionModel. The difference is that UNet2DModel is simpler because it doesn’t use text conditioning (for text-to-image generation), so if you don’t need to train your model for text-to-image tasks, you can use UNet2DModel instead and training should be faster. This is the revision where that feature was added, and it’s from the PR, so it should outlive future changes in main :). You’d use it the same way we discussed (see the sketch after the list):

  • You select a class-conditioning embedding type when you create the UNet.
  • You pass your custom class labels in the forward pass.
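
For instance, here’s a minimal sketch with placeholder sizes, assuming a diffusers version that includes the new feature. Note there’s no encoder_hidden_states argument, since there’s no text conditioning:

```python
import torch
from diffusers import UNet2DModel

# Tiny placeholder config; num_class_embeds enables a learned
# nn.Embedding lookup for integer class labels.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(64, 128),
    num_class_embeds=10,
)

sample = torch.randn(4, 3, 32, 32)
timesteps = torch.randint(0, 1000, (4,))
class_labels = torch.randint(0, 10, (4,))  # your custom labels, in the forward pass

pred = unet(sample, timesteps, class_labels=class_labels).sample
```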

This is great, thanks. I’ll be using both text and this new conditioning info (which I’ll pass via the class-conditioning mechanism), so I’ll stick with UNet2DConditionModel… But it’s cool that UNet2DModel now has the option for class conditioning, so thanks for the heads-up!
