from io import BytesIO
import requests
import torch
from PIL import Image
from diffusers import DiffusionPipeline, DDIMScheduler
import numpy as np
import PIL
from packaging import version
# Pillow moved the resampling constants to PIL.Image.Resampling in 9.1.0
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------
def preprocess(image):
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # round size down to an integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0  # HWC in [0, 1]
    image = image[None].transpose(0, 3, 1, 2)  # NCHW
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # scale to [-1, 1], as the VAE expects
def get_text_embeddings(prompt, tokenizer, text_encoder, device):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    return text_embeddings
def get_latents(image, vae, generator, dtype, device):
    image = preprocess(image)
    image = image.to(device=device, dtype=dtype)
    init_latent_image_dist = vae.encode(image).latent_dist
    image_latents = init_latent_image_dist.sample(generator=generator)
    image_latents = 0.18215 * image_latents  # Stable Diffusion v1 VAE scaling factor
    return image_latents
def get_pipeline_components(pipe):
    vae = pipe.vae                    # the VAE model
    unet = pipe.unet                  # the U-Net model
    text_encoder = pipe.text_encoder  # the text encoder model
    tokenizer = pipe.tokenizer        # the tokenizer
    scheduler = pipe.scheduler        # the scheduler
    return vae, unet, text_encoder, tokenizer, scheduler
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set the device
device = get_device()
# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    torch_dtype=torch.float16,
).to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Load the image and text prompt
seed = 0
generator = torch.Generator(device=device).manual_seed(seed)
prompt = "A photo of Barack Obama smiling with a big grin"
url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
# Get the different components from the pipeline
vae, unet, text_encoder, tokenizer, scheduler = get_pipeline_components(pipe)
vae = vae.to(device).eval()
unet = unet.to(device).eval()
text_encoder = text_encoder.to(device).eval()
# Encode the text prompt and get the text embeddings
text_embeddings = get_text_embeddings(prompt, tokenizer, text_encoder, device=device)
text_embeddings = torch.nn.Parameter(text_embeddings, requires_grad=True)
text_embeddings_orig = text_embeddings.clone()
# Initialize the optimizer
optimizer = torch.optim.Adam(
    [text_embeddings],  # only optimize the embeddings
    lr=1e-6,
)
# Get Latent representation of the image
image_latents = get_latents(init_image, vae, generator, text_embeddings.dtype, device=device)
# Optimize the text embeddings to better reconstruct the init image
for _ in range(500):
    noise = torch.randn(image_latents.shape, dtype=text_embeddings.dtype).to(image_latents.device)
    timesteps = torch.randint(1000, (1,), device=image_latents.device)  # random timestep in [0, 1000)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
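    # i.e. noisy_latents = sqrt(alpha_bar_t) * image_latents + sqrt(1 - alpha_bar_t) * noise,
    # where alpha_bar_t is the cumulative alpha product at the sampled timestep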
    noisy_latents = scheduler.add_noise(image_latents, noise, timesteps)

    # Predict the noise residual
    noise_pred = unet(noisy_latents, timesteps, text_embeddings).sample
    loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())

The first iteration prints one loss value, then the second call to loss.backward() raises:

0.03802490234375
Traceback (most recent call last):
  File "/home/hashmat/Code/Projects/Diffusion_Attack/imagic/main.py", line 74, in <module>
    loss.backward()
  File "/home/hashmat/miniconda3/envs/pix2pix/lib/python3.11/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/hashmat/miniconda3/envs/pix2pix/lib/python3.11/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable.execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly
access saved tensors after they have already been freed). Saved intermediate
values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second
time or if you need to access saved tensors after calling backward.
Process finished with exit code 1
I am trying to optimize the text embeddings so that they reconstruct the given image (the first stage of Imagic). The loop above raises the error shown, even though, as far as I can tell, I call loss.backward() only once per iteration, and on a freshly computed loss each time. Can someone explain what the issue is?
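
For reference, here is a stripped-down toy that reproduces the same error pattern with plain tensors (encoder_weight and latents are made-up names for illustration, not from my script above): a tensor built by gradient-tracked ops once, outside the loop, gets reused inside it, just like image_latents.

import torch

encoder_weight = torch.randn(4, 4, requires_grad=True)  # stands in for the VAE weights
latents = (torch.randn(4) @ encoder_weight).tanh()      # built once outside the loop, like image_latents

for step in range(2):
    loss = ((latents - torch.randn(4)) ** 2).mean()  # new graph nodes on top of the shared tanh/matmul subgraph
    loss.backward()  # the second iteration raises the same RuntimeError

The first backward() frees the saved tensors of the shared subgraph, so the second one fails. That makes me suspect image_latents is carrying a graph through vae.encode, but I am not sure whether that is actually what is happening in my script.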