I’m learning about diffusion models, and my understanding is that, on paper, the denoising U-Net in Stable Diffusion and similar methods predicts a denoised image one timestep ahead of the current one. I was expecting that for each image, in every training iteration, the noise scheduler would add N different levels of noise to the image, so that the denoising U-Net learns to predict all the different noise levels for a given image at each step, until training converges. However, by reading the code and reimplementing the train_controlnet.py training-loop example myself, I noticed that in each iteration only one level of noise is injected into the images in latent space. Is this understanding correct? Is there a particular reason for designing it this way?
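To make the two alternatives concrete, here is a minimal sketch of what I expected versus what the loop actually does (the scheduler config, tensor shapes, and variable names below are just illustrative, not taken from train_controlnet.py):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latent = torch.randn(1, 4, 64, 64)   # a single latent, arbitrary shape
noise = torch.randn_like(latent)

# what I expected: the same latent corrupted at N different noise levels in one iteration
many_t = torch.arange(0, 1000, 100)  # e.g. N = 10 timesteps
noisy_many = scheduler.add_noise(latent.repeat(10, 1, 1, 1), noise.repeat(10, 1, 1, 1), many_t)
print(noisy_many.shape)  # torch.Size([10, 4, 64, 64])

# what the training loop does: one random timestep, i.e. one noise level, per image
one_t = torch.randint(0, 1000, (1,)).long()
noisy_one = scheduler.add_noise(latent, noise, one_t)
print(noisy_one.shape)   # torch.Size([1, 4, 64, 64])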
Here’s the code I used (based on the loop from train_controlnet.py) to check that indeed only one level of corruption is added per iteration, and that the denoising U-Net likewise predicts only one level of denoising for each image (even though that level can differ between images in the same batch, because the timestep is drawn with randint):
import torch
from torchvision.utils import save_image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
# https://huggingface.co/docs/diffusers/training/controlnet
model_id = 'runwayml/stable-diffusion-v1-5'
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
# tokenize the prompts and compute the CLIP text embeddings used as conditioning
prompts = ['a photograph of an astronaut riding a horse', 'placeholder prompt']
prompts = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
print(prompts.input_ids.shape)
encoder_hidden_states = text_encoder(prompts.input_ids)[0]
print(encoder_hidden_states.shape)
# two random dummy "images" in [0, 1], just to exercise the shapes (the real loop normalizes pixels to [-1, 1] first)
img = torch.rand(2, 3, 512, 512)
latents = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
print(latents.shape)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# here is where the noise is added: each image gets a single noise level, determined by its randomly sampled timestep (see the sketch after this listing for what add_noise computes)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
print(noisy_latents.shape)
# the U-Net makes one prediction per image, conditioned on that image's sampled timestep (with SD 1.5's default epsilon parameterization this is the predicted noise, not the next latent)
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
print(model_pred.shape)
# decoding model_pred is only a shape check here, since model_pred is the noise prediction rather than a denoised latent
decoded = vae.decode(model_pred).sample
print(decoded.shape)
save_image(decoded, 'samples.png', nrow=int(decoded.shape[0] ** 0.5))
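If I'm reading DDPMScheduler.add_noise correctly, it applies the closed-form forward process q(x_t | x_0) directly, so any noise level t can be reached from the clean latent in a single call, without ever producing the intermediate levels. Here is a sketch of that reading (the function and variable names are my own; only scheduler.alphas_cumprod and scheduler.add_noise come from diffusers):

import torch
from diffusers import DDPMScheduler

def add_noise_closed_form(scheduler, x0, noise, timesteps):
    # q(x_t | x_0): x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    abar = scheduler.alphas_cumprod.to(device=x0.device, dtype=x0.dtype)[timesteps]
    abar = abar.view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise

# sanity check against the scheduler's own implementation
scheduler = DDPMScheduler(num_train_timesteps=1000)
x0 = torch.randn(2, 4, 64, 64)
eps = torch.randn_like(x0)
t = torch.randint(0, 1000, (2,)).long()
print(torch.allclose(add_noise_closed_form(scheduler, x0, eps, t), scheduler.add_noise(x0, eps, t), atol=1e-5))

If that reading is right, then each noise level only needs to be covered over the course of many iterations and epochs rather than within a single one, which would explain the single randomly sampled timestep per image. Is that the intended reasoning, or is there more to it?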