OOM error while generating latents from the UNet with Self-Attention Guidance

Hi everyone, I’m trying to implement Stable Diffusion v1.5 with Self-Attention Guidance (SAG).

I’m writing my own inference pipeline with diffusers, and I’m loading the autoencoder, unet, tokenizer and text_encoder of the model with half-precision weights. I’m running this on a Colab T4 GPU.

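import torch
from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# Autoencoder: decodes latents into image space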
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, subfolder="vae")

# Tokenizer and text encoder to tokenize and encode the prompt
# (the tokenizer holds no weights, so it takes no torch_dtype)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, subfolder="text_encoder")

# UNet- generate latents
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, subfolder="unet")

# The noise scheduler
# PNDMScheduler
scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

I’m running this for 50 inference steps. In the denoising loop below, I run out of memory at the UNet call that predicts noise from the degraded latents:

# forward and give guidance
degraded_pred = unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample

This is my denoising loop:

with unet.mid_block.attentions[0].register_forward_hook(get_map_size):
        for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
            # expand the latents to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            # Scale the latents (preconditioning):
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            ## sag related
            pred_x0 = pred_X0(latents, noise_pred_uncond, t)

            # get the stored attention maps
            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
            # self-attention-based degrading of latents
            degraded_latents = sag_masking(pred_x0, uncond_attn, map_size, t, noise_pred_uncond)
            uncond_emb, _ = text_embeddings.chunk(2)
            # forward and give guidance
            degraded_pred = unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
            ##
            # compute the previous noisy sample x_t -> x_t-1 (step the scheduler once per iteration)
            # latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
            latents = scheduler.step(noise_pred, t, latents).prev_sample
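
One thing that may be relevant here: in the loop above, the first UNet call sits inside torch.no_grad(), but the call on the degraded latents does not, so autograd would record a graph for that pass and keep its activations in memory. The snippet below is only a minimal sketch of that difference. It uses a tiny stand-in module (`toy_unet` is a hypothetical name, not the real UNet) and runs on CPU.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet; it only illustrates autograd behavior.
toy_unet = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.SiLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
latents = torch.randn(1, 4, 64, 64)

# Outside no_grad: the output is attached to a graph, so the
# intermediate activations needed for backward are retained.
tracked_pred = toy_unet(latents)

# Inside no_grad: no graph is recorded, so intermediates can be freed.
with torch.no_grad():
    untracked_pred = toy_unet(latents)

print("tracked:", tracked_pred.requires_grad, tracked_pred.grad_fn is not None)
print("untracked:", untracked_pred.requires_grad, untracked_pred.grad_fn is not None)
```

If this is what is happening, putting the whole loop body under a single `with torch.no_grad():` (or decorating the sampling function with `@torch.no_grad()`) would cover both UNet calls. Whether that alone is enough to fit on a T4 is an assumption I have not verified.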

Any help in this regard is much appreciated!