I’m learning about diffusion, and I know that on paper the denoising U-Net in Stable Diffusion and other similar methods predicts a denoised image one timestep ahead of the current one. I was expecting that, for each image in each training iteration, the noise scheduler would add N different levels of noise to the image, so the denoising U-Net learns to predict all the different noise levels for a given image at each step, until training converges. However, by looking at the code and reimplementing it myself outside of the train_controlnet.py train loop example, I noticed that at each training iteration only 1 level of noise is injected into the images in latent space. Is this understanding correct? Is there any particular reason for designing it this way?
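Part of why I expected this: my understanding is that the forward process is closed-form, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so the scheduler can jump to any noise level in a single call instead of iterating t steps. Here's a quick sanity check of that (using a random tensor as a stand-in for a real latent — please correct me if this is wrong):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

x0 = torch.randn(1, 4, 64, 64)  # stand-in for a clean latent
eps = torch.randn_like(x0)
t = torch.tensor([500])         # an arbitrary noise level

# what the library does
xt = scheduler.add_noise(x0, eps, t)

# manual closed-form forward process
abar = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
xt_manual = abar.sqrt() * x0 + (1 - abar).sqrt() * eps

print(torch.allclose(xt, xt_manual, atol=1e-6))  # True
```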
Here’s the code I used (based on the loop from train_controlnet.py) to check that indeed there’s only 1 level of corruption added at each iteration, and the denoising U-Net also only predicts 1 level of denoising for each image (even if the level it predicts for each image in the same batch may differ, due to the per-image random timestep sampling):
```python
import torch
from torchvision.utils import save_image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

# https://huggingface.co/docs/diffusers/training/controlnet
model_id = 'runwayml/stable-diffusion-v1-5'
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# text conditioning
prompts = ['a photograph of an astronaut riding a horse', 'placeholder prompt']
prompts = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length,
                    truncation=True, return_tensors="pt")
print(prompts.input_ids.shape)
encoder_hidden_states = text_encoder(prompts.input_ids)[0]
print(encoder_hidden_states.shape)

# random stand-in for an image batch, encoded into latent space
img = torch.rand(2, 3, 512, 512)
latents = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
print(latents.shape)

# sample one random timestep per image in the batch
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# here is where the noise is added: noise for 1 step is added,
# based on the randomly sampled timestep for each image in the batch
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
print(noisy_latents.shape)

# the U-Net makes only 1 prediction per image, for its sampled timestep
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
print(model_pred.shape)

# decode back to pixel space just to inspect the output
decoded = vae.decode(model_pred).sample
print(decoded.shape)
save_image(decoded, 'samples.png', nrow=int(decoded.shape[0] ** 0.5))
```
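For comparison, here's a minimal sketch of the design I was expecting — corrupting the same latent at N different noise levels in one pass, by repeating it along the batch dimension (the shapes and N are hypothetical stand-ins, not anything from the training script):

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# one latent; (1, 4, 64, 64) is the SD 1.5 latent shape for a 512x512 image
latent = torch.randn(1, 4, 64, 64)
N = 8  # number of noise levels per image (hypothetical)

# repeat the same latent N times and pair each copy with a different timestep
latents = latent.repeat(N, 1, 1, 1)
noise = torch.randn_like(latents)
timesteps = torch.linspace(0, noise_scheduler.config.num_train_timesteps - 1, N).long()

# add_noise corrupts each sample at its own timestep, so a single call yields
# the same image at N different noise levels
noisy = noise_scheduler.add_noise(latents, noise, timesteps)
print(noisy.shape)  # torch.Size([8, 4, 64, 64])
```

As far as I can tell nothing mechanically prevents this, so I assume the 1-level-per-image choice in the training scripts is deliberate.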