Add additional conditioning info

Unfortunately, it seems like there’s a significant missing piece here.

I thought I had trained on my data with the class embeddings, but I don’t think I did. Stepping through the code, it looks like the class embeddings are silently skipped if class_embed_type isn’t set (yes, you did mention this), but when I try to set it manually I crash with the following error:

File "/home/james/anaconda3/envs/riffusion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 987, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!
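
(For what it’s worth, the silent-skip behaviour is easy to confirm on the loaded unet without stepping through the source:)

# with no class_embed_type in config.json the class embedding is never created,
# so anything passed as class_labels is quietly ignored in forward()
print(unet.config.class_embed_type)  # -> None
print(unet.class_embedding)          # -> None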

I tried both setting the class embedding type in config.json and passing it as an argument to from_pretrained() when I instantiate the unet, but I’m guessing it fails because there are no weights for the class embeddings in diffusion_pytorch_model.bin, so it can’t instantiate them.
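
For reference, this is roughly the call that produces the error above (the model id is just an example, not necessarily the checkpoint I’m actually loading):

from diffusers import UNet2DConditionModel

# crashes with the meta-tensor error: class_embed_type adds a class_embedding module,
# but diffusion_pytorch_model.bin has no weights for it
unet = UNet2DConditionModel.from_pretrained(
    "riffusion/riffusion-model-v1",
    subfolder="unet",
    class_embed_type="timestep",
)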

So perhaps I’m forced to train from scratch… which is actually fine, but how do I do that???


Okay, I think I worked out a way to get started:

unet = UNet2DConditionModel(class_embed_type='timestep')

And I have a feeling this works, because I run out of CUDA memory when trying to process it with my embedding! :rofl:

(Fortunately I now have access to a bigger GPU, so I’ll give it a try on that…)

But please let me know if there’s another (or a better) way!


Another update. I had mistakenly assumed the unet was using the default values; adding the non-default values (from config.json) to the init got me further:

unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768, class_embed_type='timestep')
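
(A possibly cleaner alternative I haven’t tested: load the checkpoint’s config.json directly so all the non-default values come along automatically, then override class_embed_type. This assumes the load_config / from_config helpers are available in the diffusers version I’m on, and the model id is again just an example:)

from diffusers import UNet2DConditionModel

# untested: read every non-default value from the checkpoint's config.json,
# override class_embed_type, and build a randomly initialised unet from that config
config = UNet2DConditionModel.load_config("riffusion/riffusion-model-v1", subfolder="unet")
config["class_embed_type"] = "timestep"
unet = UNet2DConditionModel.from_config(config)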

However, I ran into shape problems when using the timestep type. I’ve at least been able to get the model training by using identity instead, and adding a block to the unet’s forward that adjusts the shape of my custom conditioning embedding, like so:

class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
if class_emb.shape != emb.shape:
    # flatten so element counts are comparable regardless of the input shape
    class_emb = class_emb.flatten()
    emb_len = emb.nelement()
    cl_emb_len = class_emb.nelement()
    if cl_emb_len > emb_len:
        # the custom embedding is larger than emb: here we can only truncate
        class_emb = class_emb[:emb_len]
    else:
        # the custom embedding is smaller: repeat it and zero-pad the remainder
        cl_emb_repeat = emb_len // cl_emb_len
        cl_em_pad_len = emb_len - (cl_emb_repeat * cl_emb_len)
        cl_em_pad = torch.zeros(cl_em_pad_len, dtype=class_emb.dtype, device=emb.device)
        class_emb = torch.cat((class_emb.repeat(cl_emb_repeat), cl_em_pad), 0)
    # either way, match emb's shape before adding
    class_emb = class_emb.reshape(emb.shape)

emb = emb + class_emb

This at least allows me to use the class_labels argument to pass in my (non-class) custom conditioning embedding. If this is clearly a bad idea, any help would be greatly appreciated.
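
In case it helps to judge, this is roughly how I pass the embedding in at train time; cond_emb, noisy_latents, timesteps and text_embeddings are just placeholders for my own tensors:

# rough sketch of my training call; cond_emb is the custom conditioning embedding
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=text_embeddings,
    class_labels=cond_emb,  # non-class conditioning, reshaped in the forward() block above
).sample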
