Optimizing text embedding raises issue of trying to backward through the graph second time

from io import BytesIO

import requests
import torch
from PIL import Image
from diffusers import DiffusionPipeline, DDIMScheduler
import numpy as np
import PIL
import torch
from packaging import version


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------

def preprocess(image):
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def get_text_embeddings(prompt, tokenizer, text_encoder, device):
    text_input = tokenizer(prompt,
                           padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True,
                           return_tensors="pt", )

    text_embeddings_orig = text_encoder(text_input.input_ids.to(device))[0]

    return text_embeddings_orig

def get_latents(image, vae, generator, dtype, device):
    image = preprocess(image)
    image = image.to(device=device, dtype=dtype)
    init_latent_image_dist = vae.encode(image).latent_dist
    image_latents = init_latent_image_dist.sample(generator=generator)
    image_latents = 0.18215 * image_latents
    return image_latents

def get_pipeline_components(pipe):
    vae  = pipe.vae # the VAE model
    unet = pipe.unet # the U-Net model
    text_encoder = pipe.text_encoder # the text encoder model
    tokenizer = pipe.tokenizer # the tokenizer
    scheduler = pipe.scheduler # the scheduler
    return vae, unet, text_encoder, tokenizer, scheduler

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")

# Set the device
device = get_device()

# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    torch_dtype=torch.float16,
).to(device)

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)




# Load the image and text prompt
generator = torch.Generator("cuda").manual_seed(0)
seed = 0
prompt = "A photo of Barack Obama smiling with a big grin"
url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))

#image = pipe(prompt, generator=generator).images[0]



# Get the different components from the pipeline
vae, unet, text_encoder, tokenizer, scheduler = get_pipeline_components(pipe)
vae = vae.to(device).eval()
unet = unet.to(device).eval()
text_encoder = text_encoder.to(device).eval()

# Encode the text prompt and get the text embeddings
text_embeddings = get_text_embeddings(prompt, tokenizer, text_encoder, device=device)
text_embeddings = torch.nn.Parameter(text_embeddings, requires_grad=True)
text_embeddings_orig = text_embeddings.clone()

# Initialize the optimizer
optimizer = torch.optim.Adam(
    [text_embeddings],  # only optimize the embeddings
    lr=0.000001,
)


# Get Latent representation of the image
image_latents = get_latents(init_image, vae, generator, text_embeddings.dtype, device=device)

# Optimize the text embeddings to better reconstruct the init image
for _ in range(500):
    noise = torch.randn(image_latents.shape, dtype=text_embeddings.dtype).to(image_latents.device)
    timesteps = torch.randint(1000, (1,), device=image_latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = scheduler.add_noise(image_latents, noise, timesteps)

    # Predict the noise residual
    noise_pred = unet(noisy_latents, timesteps, text_embeddings).sample

    loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    print(loss.item())





0.03802490234375
โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Traceback (most recent call last) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ
โ”‚ /home/hashmat/Code/Projects/Diffusion_Attack/imagic/main.py:74 in โ”‚
โ”‚ โ”‚
โ”‚ 71 โ”‚ โ”‚
โ”‚ 72 โ”‚ loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="n โ”‚
โ”‚ 73 โ”‚ โ”‚
โ”‚ โฑ 74 โ”‚ loss.backward() โ”‚
โ”‚ 75 โ”‚ โ”‚
โ”‚ 76 โ”‚ optimizer.step() โ”‚
โ”‚ 77 โ”‚ optimizer.zero_grad() โ”‚
โ”‚ โ”‚
โ”‚ /home/hashmat/miniconda3/envs/pix2pix/lib/python3.11/site-packages/torch/_te โ”‚
โ”‚ nsor.py:487 in backward โ”‚
โ”‚ โ”‚
โ”‚ 484 โ”‚ โ”‚ โ”‚ โ”‚ create_graph=create_graph, โ”‚
โ”‚ 485 โ”‚ โ”‚ โ”‚ โ”‚ inputs=inputs, โ”‚
โ”‚ 486 โ”‚ โ”‚ โ”‚ ) โ”‚
โ”‚ โฑ 487 โ”‚ โ”‚ torch.autograd.backward( โ”‚
โ”‚ 488 โ”‚ โ”‚ โ”‚ self, gradient, retain_graph, create_graph, inputs=inputs โ”‚
โ”‚ 489 โ”‚ โ”‚ ) โ”‚
โ”‚ 490 โ”‚
โ”‚ โ”‚
โ”‚ /home/hashmat/miniconda3/envs/pix2pix/lib/python3.11/site-packages/torch/aut โ”‚
โ”‚ ograd/init.py:200 in backward โ”‚
โ”‚ โ”‚
โ”‚ 197 โ”‚ # The reason we repeat same the comment below is that โ”‚
โ”‚ 198 โ”‚ # some Python versions print out the first line of a multi-line fu โ”‚
โ”‚ 199 โ”‚ # calls in the traceback and some print out the last line โ”‚
โ”‚ โฑ 200 โ”‚ Variable.execution_engine.run_backward( # Calls into the C++ eng โ”‚
โ”‚ 201 โ”‚ โ”‚ tensors, grad_tensors
, retain_graph, create_graph, inputs, โ”‚
โ”‚ 202 โ”‚ โ”‚ allow_unreachable=True, accumulate_grad=True) # Calls into th โ”‚
โ”‚ 203 โ”‚
โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ
RuntimeError: Trying to backward through the graph a second time (or directly
access saved tensors after they have already been freed). Saved intermediate
values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second
time or if you need to access saved tensors after calling backward.

Process finished with exit code 1

I am trying to optimize the text embedding in order to make the embedding output the given image. However, it raises the above error. However, I am not doing the backward pass on the loss twice. Can someone please explain the issue here?

Hello, Iโ€™m doing something similar to your case (optimizing the text embedding) and Iโ€™m getting the same error. Did you manage to solve it?