I am trying to guide a model's output towards a certain class. I adapted the DDPMPipeline to use a UNet2DConditionModel and added a guide method to the pipeline, but it does not seem to guide the sampling process at all.
My class labels are multi-hot vectors of the form [0,0,0,0,1,0,0,0,1]. I pass them through an embedding layer so that they can be used as the encoder hidden states (hidden_states) of the UNet2DConditionModel during both training and inference.
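For reference, a minimal sketch of this embedding step, assuming the label encoder is built from torch.nn.Embedding layers. The class name MultiHotEmbedding, the sizes, and the extra per-attribute index embedding are hypothetical and not taken from the notebook. Each label position becomes one context token, giving a (batch, num_attrs, dim) tensor, which is the 3-D shape the cross-attention layers expect for encoder hidden states. A plain nn.Embedding(2, dim) would give every attribute token a vector that depends only on its 0/1 value, so the index embedding here is one way to let the tokens differ by attribute.

```python
import torch
import torch.nn as nn


class MultiHotEmbedding(nn.Module):
    """Hypothetical label encoder: multi-hot vector -> sequence of context tokens."""

    def __init__(self, num_attrs: int, dim: int):
        super().__init__()
        # Embeds the 0/1 value of each attribute
        self.value_emb = nn.Embedding(2, dim)
        # Embeds the attribute's position, so token i differs per attribute
        self.index_emb = nn.Embedding(num_attrs, dim)
        self.register_buffer("index", torch.arange(num_attrs), persistent=False)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        # labels: (batch, num_attrs) long tensor of 0s and 1s
        # returns: (batch, num_attrs, dim); index_emb output broadcasts over batch
        return self.value_emb(labels) + self.index_emb(self.index)


emb = MultiHotEmbedding(num_attrs=9, dim=32)
labels = torch.tensor([[0, 0, 0, 0, 1, 0, 0, 0, 1]])
context = emb(labels)
print(context.shape)  # torch.Size([1, 9, 32])
```

In a real setup, dim would have to match the UNet's cross_attention_dim setting.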
I have also added classifier-free guidance logic: at each step, the difference between the conditional and unconditional noise predictions is multiplied by a guidance scale and added to the unconditional prediction.
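This is the standard classifier-free guidance rule, eps = eps_uncond + s * (eps_cond - eps_uncond). A minimal framework-agnostic sketch (the function name cfg_combine is hypothetical) that works on floats, NumPy arrays or torch tensors alike:

```python
def cfg_combine(noise_uncond, noise_cond, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction
    towards (and, for guidance_scale > 1, past) the conditional one."""
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


# Scalar illustration: unconditional = 1.0, conditional = 2.0
print(cfg_combine(1.0, 2.0, 0.0))  # 1.0 (purely unconditional)
print(cfg_combine(1.0, 2.0, 1.0))  # 2.0 (purely conditional)
print(cfg_combine(1.0, 2.0, 3.0))  # 4.0 (extrapolated past the conditional prediction)
```

A scale of 1 reduces to plain conditional sampling, so guidance only adds anything beyond that for scales above 1. The extrapolation also assumes the same network has learned a meaningful prediction for the "null" condition (here, the all-zeros label vector).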
```python
import torch
from tqdm.auto import tqdm


@torch.no_grad()
def guide(self, guidance_scale, batch_size=1, generator=None, torch_device=None, output_type="pil", hidden_states=None):
    # hidden_states: already-embedded conditional labels, shape (batch_size, 40, dim)
    # emb: the label-embedding module defined elsewhere in the notebook
    if torch_device is None:
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    self.unet.to(torch_device)

    # Sample Gaussian noise to begin the loop
    latents = torch.randn(
        (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
        generator=generator,
    )
    image = latents.to(torch_device)

    # Set step values
    self.scheduler.set_timesteps(1000)

    # "Unconditional" context: an all-zeros label vector (40 attributes) per sample
    no_condition_attrs = torch.zeros(batch_size, 40, dtype=torch.long, device=torch_device)
    y = emb(no_condition_attrs)

    for t in tqdm(self.scheduler.timesteps):
        # 1. Predict the noise with and without the condition
        # latents_input = torch.cat([latents] * 2)
        # context = torch.cat([[y], [hidden_states]])
        noise_prediction_text = self.unet(image, t, hidden_states)["sample"]
        noise_pred_uncond = self.unet(image, t, y)["sample"]
        # noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

        # 2. Compute the previous image: x_t -> x_{t-1}
        image = self.scheduler.step(noise_pred, t, image)["prev_sample"]

    # Map from [-1, 1] to [0, 1] and convert to NHWC NumPy
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "pil":
        image = self.numpy_to_pil(image)
    return {"sample": image}
```
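The commented-out lines in the loop hint at the usual single-batch variant, where the unconditional and conditional passes share one forward call. A minimal runnable sketch of that pattern, using a hypothetical stub_unet in place of self.unet (it only mimics the call shape and returns a dict with a "sample" key). Note that it duplicates the current sample image rather than the initial latents, and concatenates the two contexts along the batch dimension with torch.cat([y, hidden_states]) instead of wrapping them in nested lists:

```python
import torch


def stub_unet(sample, t, encoder_hidden_states):
    # Hypothetical stand-in for self.unet: the "noise prediction" depends on the
    # context so that the conditional and unconditional halves differ.
    shift = encoder_hidden_states.mean(dim=(1, 2)).view(-1, 1, 1, 1)
    return {"sample": sample * 0.1 + shift}


def guided_noise_pred(unet, image, t, y_uncond, context_cond, guidance_scale):
    # Duplicate the current sample: first half unconditional, second half conditional
    latent_input = torch.cat([image] * 2)
    context = torch.cat([y_uncond, context_cond])  # along the batch dimension
    noise_pred = unet(latent_input, t, context)["sample"]
    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


batch, num_attrs, dim = 2, 9, 32
image = torch.randn(batch, 3, 8, 8)
y_uncond = torch.zeros(batch, num_attrs, dim)
context_cond = torch.ones(batch, num_attrs, dim)

batched = guided_noise_pred(stub_unet, image, 10, y_uncond, context_cond, 7.5)

# Same result as two separate forward passes
uncond = stub_unet(image, 10, y_uncond)["sample"]
cond = stub_unet(image, 10, context_cond)["sample"]
separate = uncond + 7.5 * (cond - uncond)
print(torch.allclose(batched, separate))  # True
```

The batched form halves the number of UNet calls per step at the cost of a doubled batch; it should give the same guided prediction as the two-call version above.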
The full code is in a Colab notebook here.
With this setup, however, I am unable to steer the output towards the intended label at all.
Does anyone have any ideas?