Unfortunately, it seems like there’s a significant missing piece here.
I thought I had trained on my data with the class embeddings, but I don’t think I did. Stepping through the code, it looks like the class embeddings are silently skipped if class_embed_type isn’t set (yes, you did mention this), but when I try to set it manually I crash with the following error:
File "/home/james/anaconda3/envs/riffusion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 987, in convert
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!
I tried both setting the class embedding type in config.json and passing it as an argument to from_pretrained() when I instantiate the unet, but I’m guessing it fails because there are no weights for the class embeddings in diffusion_pytorch_model.bin, so it can’t instantiate them.
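For reference, this is roughly what the failing attempt looked like; the checkpoint path is a placeholder for whatever you’re loading, and passing the override as a from_pretrained() kwarg is just my guess at the intended usage:

from diffusers import UNet2DConditionModel

# Roughly what I tried (checkpoint path is a placeholder, not the actual model id).
# This is the call that dies with "Cannot copy out of meta tensor; no data!",
# presumably because the checkpoint contains no weights for the class-embedding layers.
unet = UNet2DConditionModel.from_pretrained(
    "path/to/pretrained-model",
    subfolder="unet",
    class_embed_type="timestep",
)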
So perhaps I’m forced to train from scratch… which is actually fine, but how do I do that???
Okay, I think I worked out a way to get started:
unet = UNet2DConditionModel(class_embed_type='timestep')
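A minimal sketch of that, with a quick sanity check that the class-embedding module actually gets created (the unet’s forward references self.class_embedding, so I’m assuming that’s the attribute to look at):

from diffusers import UNet2DConditionModel

# Fresh, untrained unet with the class-embedding path enabled.
unet = UNet2DConditionModel(class_embed_type="timestep")

# Sanity check: this should now be a real module instead of None
# (the forward pass silently skips class embeddings when it is None).
print(unet.class_embedding)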
And I have a feeling this works, because I run out of CUDA memory when trying to process it with my embedding!
(Fortunately I now have access to a bigger GPU, so I’ll give it a try on that…)
But please let me know if there’s another (or a better) way!
Another update. I had mistakenly assumed the unet was using the default values; adding the non-default values (from config.json) to the init got me further:
unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768, class_embed_type='timestep')
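In case it helps anyone, here’s a sketch of a way to avoid copying the non-default values over by hand; the checkpoint path is a placeholder, and I’m assuming diffusers’ load_config / from_config behave the way I think they do:

from diffusers import UNet2DConditionModel

# Sketch (checkpoint path is a placeholder): pull the full config from the pretrained
# checkpoint so none of the non-default values are missed, flip on the class embedding,
# and build a fresh (untrained) unet from that config.
config = UNet2DConditionModel.load_config("path/to/pretrained-model", subfolder="unet")
config["class_embed_type"] = "timestep"
unet = UNet2DConditionModel.from_config(config)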
However, I’m running into shape problems when using the timestep type. I’ve at least been able to get the model training by using identity instead, then adding a block to the unet’s forward to adjust the shape of my custom conditioning embedding, like so:
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
if class_emb.shape != emb.shape:
    # flatten so element-wise truncation / repetition works whatever shape comes in
    cl_emb = class_emb.flatten()
    emb_len = emb.nelement()
    cl_emb_len = cl_emb.nelement()
    if cl_emb_len > emb_len:
        # too long: all we can do is truncate
        cl_emb = cl_emb[:emb_len]
    else:
        # too short: repeat, then zero-pad up to the length of emb
        cl_emb_repeat = emb_len // cl_emb_len
        cl_emb_pad_len = emb_len - (cl_emb_repeat * cl_emb_len)
        cl_emb_pad = torch.zeros(cl_emb_pad_len, dtype=cl_emb.dtype, device=emb.device)
        cl_emb = torch.cat((cl_emb.repeat(cl_emb_repeat), cl_emb_pad), 0)
    # reshape in both cases so the addition below is well-defined
    class_emb = cl_emb.reshape(emb.shape)
emb = emb + class_emb
This at least allows me to use the class_labels argument to pass in my (non-class) custom conditioning embedding. If this is clearly a bad idea, any help would be greatly appreciated.
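For context, this is roughly how the custom embedding gets passed in during the training step (variable names are mine, not from any particular script):

# Hypothetical training-step call (variable names are placeholders): the custom
# conditioning embedding rides in through class_labels and gets folded into `emb`
# by the block above.
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=text_embeddings,
    class_labels=my_conditioning_embedding,
).sample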