Add additional conditioning info

Hi All,

Does anybody have any guidance as to how/where to add further conditioning info to the HF stable diffusion training/inference pipelines? Everything I’ve read about stable diffusion suggests that multiple different types of conditioning should be possible, but I’m not sure how to integrate it. Since the text embeddings are integrated using cross-attention, I feel like it should probably be added there, but how? Would I concatenate it to the text embeddings, for example?

Any thoughts appreciated.

Hi @jbmaxwell! That’s an excellent question.

The easiest way, I think, would be to leverage UNet2DConditionModel and indicate that you’ll be using custom class embeddings. Similar to what you suspected, these embeddings are simply added to the timestep embeddings. If you use the "timestep" class_embed_type, for example, then you pass your custom class labels during the forward pass; those values are run through an embedding layer and added to the timestep embeddings.
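In code, it would look roughly like this (a minimal sketch; the shapes and values are just for illustration, not from a real checkpoint):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    cross_attention_dim=768,
    class_embed_type="timestep",  # custom labels get a timestep-style embedding
)

noisy_latents = torch.randn(1, 4, 64, 64)   # latent sample being denoised
timesteps = torch.tensor([10])
text_embeddings = torch.randn(1, 77, 768)   # stand-in for the CLIP text encoder output
class_labels = torch.tensor([3])            # your custom conditioning values

noise_pred = unet(noisy_latents, timesteps, text_embeddings, class_labels=class_labels).sample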

I hope that’s enough to get you started! Please do share whether it works, as well as what you’re trying to achieve (if you can make it public).


Excellent, thanks so much @pcuenq!

Okay, I’ve got a bit further…

I’ve trained a VQ-VAE to generate my conditioning embeddings, but I’m wondering whether I can/should pass the (integer) latent code straight in as my “custom class labels”, or if I should/must normalize them first? If I normalize them, is it (0,1), or (-1, 1), or… ? :slight_smile:

Any help appreciated.

Oh! Also, this tensor contains duplicates. Should I remove the duplicates? (My concern is that doing so will change the shape…)

Hi @pcuenq, I’ve just come back to this today and I think your links above have changed/moved, i.e., the code was probably updated, so they no longer point to the right lines. Just an FYI, since the answer might be a bit confusing for future readers (I went through it the other day, so it’s not a huge deal right away). Not sure if there’s a way to avoid this in the future…?

Hi @jbmaxwell!

You are right, I should have used a tag instead of main. Sorry about that.

Since we last talked, we’ve added optional class conditioning to UNet2DModel, in addition to what was already available in UNet2DConditionModel. The difference is that UNet2DModel is simpler because it doesn’t use text conditioning (for text-to-image generation), so if you don’t need to train your model for text-to-image tasks, you can use UNet2DModel instead and training should be faster. This is the revision where that feature was added; it’s from the PR, so it should outlive future changes in main :). You’d use it the same way we discussed (there’s a minimal sketch after the list):

  • You select a class-conditioning embedding type when you create the UNet.
  • You pass your custom class labels in the forward pass.
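A minimal sketch of what that looks like (the sizes and the number of classes are just examples; this uses the default learned-embedding type via num_class_embeds):

import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    num_class_embeds=10,  # a learned embedding table with 10 entries
)

sample = torch.randn(4, 3, 32, 32)
timesteps = torch.tensor([10, 20, 30, 40])
class_labels = torch.tensor([0, 1, 2, 3])  # integer class indices

noise_pred = unet(sample, timesteps, class_labels=class_labels).sample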

This is great, thanks. I will be using both text and this new conditioning info (which I’ll pass via the class-conditioning mechanism), so I’ll stick with UNet2DConditionModel… But it’s cool that UNet2DModel has the option for class-conditioning now, so thanks for the heads-up!


Hi again, @pcuenq.

I think I managed to run some training with my additional conditioning info, and now I’m trying to test inference. Is there a straightforward way to use the “class labels” during inference, i.e., in one of the pipelines? I didn’t see anything obvious, so I’ve been working on an adaptation of StableDiffusionPipeline to do it… but I thought I’d ask, in case there’s something simpler I can make use of.
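For context, the core of my adaptation just threads class_labels through the unet call in the denoising loop, roughly like this (heavily simplified, no classifier-free guidance; the components are whatever the pipeline was loaded with):

import torch

@torch.no_grad()
def generate(unet, vae, scheduler, text_embeddings, my_conditioning, device="cuda"):
    latents = torch.randn(1, unet.config.in_channels, 64, 64, device=device)
    scheduler.set_timesteps(50)
    latents = latents * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        noise_pred = unet(
            latents, t,
            encoder_hidden_states=text_embeddings,
            class_labels=my_conditioning,  # the extra conditioning tensor
        ).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # 0.18215 is the SD v1 latent scaling factor
    return vae.decode(latents / 0.18215).sample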

Thanks!

Unfortunately, it seems like there’s a significant missing piece here.

I thought I had trained on my data with the class embeddings, but I don’t think I did. Stepping through the code, it looks like the class embeddings are silently skipped if class_embed_type isn’t set (yes, you did mention this), but when I try to set it manually I crash with the following error:

File "/home/james/anaconda3/envs/riffusion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 987, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

I tried both setting the class embedding type in config.json and passing it as an argument to from_pretrained() when instantiating the unet, but I’m guessing it fails because there are no weights for the class embeddings in diffusion_pytorch_model.bin, so it can’t instantiate them.

So perhaps I’m forced to train from scratch… which is actually fine, but how do I do that???


Okay, I think I worked out a way to get started:

unet = UNet2DConditionModel(class_embed_type='timestep')

And I have a feeling this works, because I run out of CUDA memory when trying to process it with my embedding! :rofl:

(Fortunately I now have access to a bigger GPU, so I’ll give it a try on that…)

But please let me know if there’s another (or a better) way!


Another update. I had mistakenly assumed the constructor defaults matched the pretrained model’s config; adding the non-default values (from config.json) to the init got me further:

unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768, class_embed_type='timestep')

However, I’m running into shape problems when using the timestep type. I’ve at least been able to get the model training by using identity instead, then adding a block to the unet’s forward to adjust the shape of my custom conditioning embedding, like so:

class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
if class_emb.shape != emb.shape:
    emb_len = emb.nelement()
    class_emb = class_emb.flatten()
    cl_emb_len = class_emb.nelement()
    if cl_emb_len > emb_len:
        # here we can only truncate
        class_emb = class_emb[:emb_len]
    else:
        # here we can repeat and zero-pad to match emb's element count
        cl_emb_repeat = emb_len // cl_emb_len
        cl_emb_pad_len = emb_len - (cl_emb_repeat * cl_emb_len)
        cl_emb_pad = torch.zeros(cl_emb_pad_len, dtype=class_emb.dtype, device=emb.device)
        class_emb = class_emb.repeat(cl_emb_repeat)
        class_emb = torch.cat((class_emb, cl_emb_pad), 0)
    # either branch leaves a flat tensor with emb.nelement() elements
    class_emb = class_emb.reshape(emb.shape)

emb = emb + class_emb

This at least allows me to use the class_labels argument to pass in my (non-class) custom conditioning embedding. If this is clearly a bad idea, any help would be greatly appreciated.


Okay, some real progress!

I trained a model with this type of conditioning and it does seem to be working. However, although it’s difficult to say for certain, I seem to be getting less influence from my custom conditioning than I would like. Basically, the text seems to have much more impact than my conditioning, and I’m wondering how to balance things out.

One thing I thought of was to move my conditioning from being added to the time embedding, emb, to being added to the text embedding, encoder_hidden_states, perhaps with a parameter to adjust the “mix” of the two (a rough sketch of what I mean is below). I may try this anyway, but if anybody has any thoughts, please share.
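Something like this, conceptually (blend_conditioning and the shapes are just my own placeholders; my conditioning would first need projecting to the text-embedding shape):

import torch

def blend_conditioning(text_states, cond_states, mix_weight=0.5):
    # both tensors are (batch, seq_len, cross_attention_dim);
    # mix_weight=0 is text only, mix_weight=1 is custom conditioning only
    return (1 - mix_weight) * text_states + mix_weight * cond_states

text_states = torch.randn(1, 77, 768)
cond_states = torch.randn(1, 77, 768)
encoder_hidden_states = blend_conditioning(text_states, cond_states, mix_weight=0.3)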

On that note, @pcuenq, I realize I’m not really clear on the roles/functions of the time embedding and the text embedding. Intuitively, it seems to me that the time embedding is related to the basic task of generating anything and directly impacts the denoising process, whereas the text embedding is an additional feature used to kind of “focus” the generation in the latent space. Is that roughly correct?

Hi @jbmaxwell! Congrats on making progress on this task!

I think your intuition is correct. The time embeddings give the model a hint about where we are in the (de)noising process. Because timesteps are semantically related to one another (they follow a progression: 4 is a later time instance than 3 but earlier than 5), they are encoded using a fancy method that tries to preserve that relationship; those are the sinusoidal embeddings you’ve probably seen in the code.
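Stripped down, the idea is something like this (the classic transformer-style formulation, not the exact diffusers code):

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # each timestep is encoded by sines/cosines at geometrically spaced
    # frequencies, so nearby timesteps map to similar vectors
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = sinusoidal_embedding(torch.arange(1000), 320)  # shape (1000, 320)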

Depending on the nature of your additional conditioning, you may not need to capture a similar relationship in your data, and that’s probably why you didn’t see great results with the timestep conditioning type, which applies the same sinusoidal method to your custom conditioning data.

For example, if you were training a model to generate 5 different classes of objects, the numerical representations of those 5 categories don’t bear any relationship to one another. In this case, you might want to explore the None class_embed_type, but indicate that your num_class_embeds is 5. (None might look like it isn’t a valid option, since only timestep and identity appear to be supported, but combined with num_class_embeds it’s actually a third choice you can use.) With this method, your model will learn to differentiate between those 5 categories, and you can then ask it to generate one of your desired subjects by supplying the class information at inference time.
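In code, the 5-class setup would be something like this (the sizes are illustrative):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    cross_attention_dim=768,
    num_class_embeds=5,  # class_embed_type stays None -> a learned nn.Embedding
)

# at training and inference time you pass the integer category index, e.g.:
# noise_pred = unet(latents, t, encoder_hidden_states, class_labels=torch.tensor([2])).sample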

Let us know if that’s something that sounds useful for your project! :slight_smile:


Thanks for the info. Very helpful!


Hi, have you successfully gotten the added conditional embedding working? If so, would you mind sharing the script? Thank you.

Hi, thanks for all of these discussions. I have one question: for the conditional text embedding, can I replace it with an image embedding? (For instance, I would like to replace image A with the part of image B that has already been generated, without text input.) Hope my question is clear.

I did get a version of this to “work”, but the effect was pretty subtle. It did seem to do something, but not what I was after, and the result was overwhelmingly dominated by the text prompt… I don’t think I have the code for that anymore, as I rewrote the script with a version that added to the text embedding instead, which was spectacularly bad, so I abandoned the effort. :joy:

You should have a look into ControlNet for what it sounds like you’re trying to do. I think there’s a ton of room for experimenting with different types of conditioning using that approach.
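e.g., with diffusers it’s along these lines (the model IDs are just examples; you’d pick or train a ControlNet for your own conditioning type, and the image URL is a placeholder):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

cond_image = load_image("https://example.com/edges.png")  # placeholder URL
image = pipe("a photo of a tidy room", image=cond_image).images[0]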


Thanks, I will read more and ask again if I have any more questions.

Hello, I also have four different classes that I want to train on. I have num_class_embeds set to 4 and class_embed_type set to None. However, I’m having trouble writing the class_labels, which causes an error at the line hidden_states = hidden_states + temb. Can you please tell me how to create the class_labels?

This is my class_labels code:
def class_label_tensor(examples, is_train=True):

    def class_tokenizer(text):
        class_names = [['C0201'], ['R0201'], ['L2016'], ['F1210']]
        class_label = text 
        num_classes = len(class_names)
        class_vector = torch.zeros(num_classes, dtype=torch.int)
        class_index = class_names.index(class_label)
        class_vector[class_index] = 1
        class_tensor = class_vector.view(1, num_classes)
        return class_tensor
    
    captions = []
    for caption in examples[caption_column]:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
    label_tensor = class_tokenizer(captions)
    return label_tensor

I always get RuntimeError: The size of tensor a (64) must match the size of tensor b (320) at non-singleton dimension 4 in my case.

Thx!

@pcuenq I am trying to build an EEG-to-image model. My EEG encoder is a separate model, and I intend to use Stable Diffusion without text conditioning; the idea is that I’ll map the EEGs to their corresponding images. Could you please guide me on where and how to attach this encoder model?
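To be a bit more concrete, my rough plan is below (EEGEncoder is a stand-in for my actual encoder; all the shapes are assumptions):

import torch
import torch.nn as nn

class EEGEncoder(nn.Module):
    # stand-in: maps an EEG recording to a sequence of conditioning vectors
    def __init__(self, eeg_features=128, seq_len=77, dim=768):
        super().__init__()
        self.proj = nn.Linear(eeg_features, seq_len * dim)
        self.seq_len, self.dim = seq_len, dim

    def forward(self, eeg):
        # eeg: (batch, eeg_features) -> (batch, seq_len, dim)
        return self.proj(eeg).view(-1, self.seq_len, self.dim)

eeg_encoder = EEGEncoder()
eeg_states = eeg_encoder(torch.randn(2, 128))
# then, in the training loop, this replaces the text conditioning:
# noise_pred = unet(noisy_latents, t, encoder_hidden_states=eeg_states).sample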


What about added_cond_kwargs? Could we pass the embeddings we have there, to add another condition? What do you think?

@pcuenq

Hello, I’m curious whether you ever made progress on this idea? I’m looking to tackle something similar for fMRI, where I’ll train a new encoder (brain → embedding) end to end with the diffusion model that I’m fine-tuning, to reconstruct the original image from my conditioning info. Let me know if you have any insights on this front.