Hi everyone, I’m trying to implement Diffusion v1.5 with Self-attention guidance (Self-Attention Guidance).
I’m writing my own inference pipeline with diffusers and I’m loading the weights for the autoencoder, unet, tokenizer and text_encoder of the model with half precision weights. I’m running this Colab T4 GPU
# Autoencoder- latents into image space
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, subfolder="vae")
# Tokenizer and Text encoder to tokenize and encode the prompt
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, subfolder="text_encoder")
# UNet- generate latents
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, subfolder="unet")
# The noise scheduler
# PNDMScheduler
scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
I’m running this for 50 inference timesteps
While running my denoising loop below, I’m running into OOM while generating latents using unet.
# forward and give guidance
degraded_pred = unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
This is my denoising loop-
with unet.mid_block.attentions[0].register_forward_hook(get_map_size):
for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
# expand the latents to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# Scale the latents (preconditioning):
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
## sag related
pred_x0 = pred_X0(latents, noise_pred_uncond, t)
# get the stored attention maps
uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
# self-attention-based degrading of latents
degraded_latents = sag_masking(pred_x0, uncond_attn, map_size, t, noise_pred_uncond)
uncond_emb, _ = text_embeddings.chunk(2)
# forward and give guidance
degraded_pred = unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
##
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents).prev_sample
# compute the previous noisy sample x_t -> x_t-1
# latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
latents = scheduler.step(noise_pred, t, latents).prev_sample
Any help in this regard is much appreciated!