Hi @jbmaxwell! That’s an excellent question.
The easiest way, I think, would be to leverage `UNet2DConditionModel` and configure it, via its `class_embed_type` argument, to use custom class embeddings. As you suspected, these embeddings are simply added to the timestep embeddings. If you use the `"timestep"` `class_embed_type`, for example, you need to pass your custom class labels (the `class_labels` argument) during the forward pass. Those values are then passed through the same sinusoidal projection used for the timesteps and a small embedding layer, and the result is added to the timestep embeddings.
I hope that’s enough to get you started! Please do share whether it works, as well as what you are trying to achieve (if you can make it public).