Image reconstruction with a diffusion model

Hello, community!

I'm currently working on a personal project that uses a diffusion model for image reconstruction. Below is an idea that didn't work, and I would like to understand why.

I used the Textual Inversion training script from diffusers. During diffusion training, the UNet predicts the noise that was added to the (latent) image. So I used the following code, adapted from the DDIM scheduler, to estimate a pseudo x_0 from the noisy latent and the UNet's predicted noise, and then decode it with the VAE.

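```python
def predict_x_0(scheduler, vae, pred_noise, latents, timesteps):
    # Look up alpha_bar_t for every sample in the batch (works for batch size > 1)
    alpha_prod_t = scheduler.alphas_cumprod.to(latents.device)[timesteps]
    # Reshape to (B, 1, 1, 1) so it broadcasts over the latent dimensions
    alpha_prod_t = alpha_prod_t.view(-1, 1, 1, 1).to(latents.dtype)
    # DDIM x_0 estimate (assumes prediction_type == "epsilon"):
    # the whole numerator is divided by sqrt(alpha_bar_t)
    pred_z_0 = (latents - (1 - alpha_prod_t) ** 0.5 * pred_noise) / alpha_prod_t ** 0.5
    # Undo the latent scaling applied after vae.encode() before decoding
    pred_x_0 = vae.decode(pred_z_0 / vae.config.scaling_factor).sample
    return pred_x_0
```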
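The grouping in the x_0 formula matters: in Python, `a - b * c / d` divides only the noise term, so the estimate is no longer the inverse of the forward process `z_t = sqrt(alpha_bar) * z_0 + sqrt(1 - alpha_bar) * eps`. A minimal standalone NumPy sketch (no diffusers; the random arrays and the `alpha_bar` value are arbitrary assumptions) that noises a fake latent and then recovers it with both groupings:

```python
import numpy as np

def add_noise(z0, noise, alpha_bar):
    # Forward process: z_t = sqrt(alpha_bar) * z_0 + sqrt(1 - alpha_bar) * eps
    return np.sqrt(alpha_bar) * z0 + np.sqrt(1.0 - alpha_bar) * noise

def x0_parenthesized(zt, eps, alpha_bar):
    # Whole numerator divided by sqrt(alpha_bar): exact inverse of add_noise
    return (zt - np.sqrt(1.0 - alpha_bar) * eps) / np.sqrt(alpha_bar)

def x0_unparenthesized(zt, eps, alpha_bar):
    # z_t - sqrt(1 - alpha_bar) * eps / sqrt(alpha_bar): only the noise term is divided
    return zt - np.sqrt(1.0 - alpha_bar) * eps / np.sqrt(alpha_bar)

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 8, 8))   # stand-in for a latent
eps = rng.standard_normal((4, 8, 8))  # stand-in for the (perfectly) predicted noise
alpha_bar = 0.3                       # hypothetical alpha_bar_t for a mid-range timestep

zt = add_noise(z0, eps, alpha_bar)
err_parenthesized = np.abs(x0_parenthesized(zt, eps, alpha_bar) - z0).max()
err_unparenthesized = np.abs(x0_unparenthesized(zt, eps, alpha_bar) - z0).max()
print(f"max error, parenthesized:   {err_parenthesized:.2e}")
print(f"max error, unparenthesized: {err_unparenthesized:.2e}")
```

The two groupings only agree when `alpha_bar` is close to 1 (small timesteps), so with the unparenthesized version the error grows as t increases, even when the predicted noise is perfect.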

Then I compute a reconstruction loss between the decoded pseudo x_0 and the input image, and add it to the textual inversion loss before backpropagation:

```python
# predict x_0 from the noisy latents and the UNet output
pred_x_0 = predict_x_0(noise_scheduler, vae, model_pred, latents, timesteps)

# reconstruction loss in pixel space (computed in float32, like the inversion loss)
recon_loss = F.mse_loss(pred_x_0.float(), batch["pixel_values"].float(), reduction="mean")

# standard textual inversion loss on the UNet prediction
inversion_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

ttl_loss = inversion_loss + recon_loss
```
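One thing that may be worth checking (an assumption, not a diagnosis): at large timesteps the latent is mostly noise, so the pseudo x_0 tends to be a blurry estimate, and an unweighted pixel MSE there may pull the embedding toward blurry images; its scale also differs from the noise MSE. A hedged NumPy sketch of a timestep-dependent weight follows. The schedule constants assume the Stable Diffusion v1.x `scaled_linear` config (`beta_start=0.00085`, `beta_end=0.012`, 1000 steps), and `recon_weight`, `min_alpha_bar` and `lam` are hypothetical names and values, not part of diffusers:

```python
import numpy as np

def sd_alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    # "scaled_linear" schedule: betas are linear in sqrt space, then squared
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_steps) ** 2
    return np.cumprod(1.0 - betas)

def recon_weight(alpha_bar_t, min_alpha_bar=0.5):
    # Hypothetical weighting: scale the pixel loss by alpha_bar_t and switch it
    # off entirely when the latent is mostly noise (alpha_bar_t < min_alpha_bar)
    return np.where(alpha_bar_t >= min_alpha_bar, alpha_bar_t, 0.0)

def combined_loss(inversion_loss, recon_loss_per_sample, alpha_bar_t, lam=0.1):
    # Weight each sample's reconstruction loss, average over the batch, scale by lam
    w = recon_weight(alpha_bar_t)
    return inversion_loss + lam * np.mean(w * recon_loss_per_sample)

alphas_cumprod = sd_alphas_cumprod()
print(recon_weight(alphas_cumprod[[10, 500, 990]]))  # weights for a few sample timesteps
```

In the training loop, the per-sample reconstruction loss would come from `F.mse_loss(..., reduction="none")` averaged over all but the batch dimension, and `alpha_bar_t` from `noise_scheduler.alphas_cumprod[timesteps]`.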


However, with this reconstruction loss the model generated images that were less similar to the target than without it. (For validation I use the prompt "A `<kodak_img>` picture itself", where `<kodak_img>` is the placeholder token used during training.)

Why might this idea not work? I would appreciate any advice on this!

Thanks in advance!