Does ControlNet (and other diffusers training scripts) apply only 1 noise injection per image per iteration of the training loop?

I’m learning about diffusion, and I know that on paper the denoising U-Net in Stable Diffusion and similar methods predicts the image denoised by one timestep from the current one. I was expecting that, for each image in each training iteration, the noise scheduler would add N different levels of noise to the image, so that the denoising U-Net learns to predict all the different noise levels for a given image at every step, until training converges. However, by reading the code and reimplementing it myself outside of the train_controlnet.py training loop, I noticed that in each iteration only 1 level of noise is injected into each image in latent space. Is this understanding correct? Is there a particular reason for designing it this way?
Here’s the code I used (based on the loop from train_controlnet.py) to check that only 1 level of corruption is indeed added per iteration, and that the denoising U-Net also predicts only 1 level of denoising for each image (even if the level predicted for each image in the same batch may differ, because of the randint):

import torch
from torchvision.utils import save_image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

# https://huggingface.co/docs/diffusers/training/controlnet
model_id = 'runwayml/stable-diffusion-v1-5'

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# text
prompts = ['a photograph of an astronaut riding a horse', 'placeholder prompt']
prompts = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
print(prompts.input_ids.shape)

encoder_hidden_states = text_encoder(prompts.input_ids)[0]
print(encoder_hidden_states.shape)

img = torch.rand(2, 3, 512, 512) * 2 - 1  # dummy image batch; the SD VAE expects pixel values in [-1, 1]

latents = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
print(latents.shape)

noise = torch.randn_like(latents)

bsz = latents.shape[0]

timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# here is where the noise is added: a single noise level per image, chosen by the randomly sampled timestep for each image in the batch
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
print(noisy_latents.shape)

# a single prediction per image: for SD 1.5 the U-Net predicts the noise (epsilon) at each image's sampled timestep
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
print(model_pred.shape)

decoded = vae.decode(model_pred).sample
print(decoded.shape)

save_image(decoded, 'samples.png', nrow=int(decoded.shape[0] ** 0.5))
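
Note: with SD 1.5 the scheduler's prediction_type is "epsilon", so model_pred above is the predicted noise, not a denoised latent, and decoding it directly does not produce an image. Here is a minimal sketch (assuming the epsilon parameterization and reusing the variables above) of recovering the predicted clean latents before decoding:

# sketch: reconstruct the predicted x0 from the epsilon prediction (assumes prediction_type == "epsilon")
alphas_cumprod = noise_scheduler.alphas_cumprod.to(latents.device)
alpha_prod_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
pred_x0 = (noisy_latents - (1 - alpha_prod_t) ** 0.5 * model_pred) / alpha_prod_t ** 0.5
decoded_x0 = vae.decode(pred_x0 / vae.config.scaling_factor).sample
save_image((decoded_x0 / 2 + 0.5).clamp(0, 1), 'pred_x0.png', nrow=int(decoded_x0.shape[0] ** 0.5))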

Yes, your understanding is correct.
Why is it done like this? Here is my reasoning:

  1. Randomizing the noise level (the timesteps value) forces the model to handle all kinds of noise levels. In theory, the timesteps value doesn’t even need to be an int; it can be a float. So if you tried to cram every noise level into one batch, you would fail, since you can’t exhaustively sample all floating-point values. In the current implementation it’s an int between 0 and 1000, but even with only 1000 levels, you can’t put 1000 samples into one batch, can you?
  2. Because you can’t put all noise levels into one batch, you have to pick some of them. Say you can fit 24 samples in a batch: would you choose the first 24 timesteps, or would you randomize them? Uniform randomization is clearly better, because it yields a gradient estimate that is more representative of the whole noise range between 0 and 1000. It’s the same reason we shuffle training samples during training.
  3. Why include 24 different images (each with one noise level) in a training batch, instead of 24 noise levels of a single image? The first option obviously gives a better gradient estimate of the whole dataset (see the sketch after this list).
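
For illustration, here is a minimal sketch (not from train_controlnet.py, just reusing the variables from your snippet above) of what packing several noise levels per image into one forward pass would look like. Mechanically it is trivial, but it spends the same batch budget on fewer distinct images:

# sketch: K noise levels per image by repeating each latent along the batch dimension
K = 4
latents_rep = latents.repeat_interleave(K, dim=0)                 # (bsz*K, 4, 64, 64)
hidden_rep = encoder_hidden_states.repeat_interleave(K, dim=0)
noise_rep = torch.randn_like(latents_rep)
timesteps_rep = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (latents_rep.shape[0],), device=latents.device).long()
noisy_rep = noise_scheduler.add_noise(latents_rep, noise_rep, timesteps_rep)
pred_rep = unet(noisy_rep, timesteps_rep, encoder_hidden_states=hidden_rep).sample  # (bsz*K, 4, 64, 64)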

In short, it’s all about estimating the gradient better and weighing the trade-offs of what to put into a batch. Remember that training on every noise level is costly, and you don’t have unlimited batch space. We want gradients that estimate the loss over the whole dataset, not over a single image.
