Hi, I’m working on a research program and I have some difficulty.
I’m trying to ‘find’ the image CLIP embedding that creates a target image that I have - I believe that it is called ‘image inversion’.
I want to do it in this way:
start with some random vector of the CLIP image space.
use an image decoder to create an image,
perform an MSE loss between the generated image and the target image,
and backpropagate this to the input, such that the random vector will change toward the real image CLIP embedding.
the problem is that i can’t backpropagate properly, i get this error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I think the computational graph is disconnected from the process of generating the image with the encoder.
what should i do?
this is my code:
from torch.optim import Adam
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import sys
from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
import torch
import PIL
import torch
from diffusers.utils import load_image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
from diffusers.models import UNet2DConditionModel
import numpy as np
DEVICE = torch.device(‘cuda:0’)
unet = UNet2DConditionModel.from_pretrained(
‘kandinsky-community/kandinsky-2-2-decoder’,
subfolder=‘unet’
).half().to(DEVICE)
decoder = KandinskyV22Pipeline.from_pretrained(
‘kandinsky-community/kandinsky-2-2-decoder’,
unet=unet,
torch_dtype=torch.float16,
output_type = ‘np’
).to(DEVICE)
prior = KandinskyV22PriorPipeline.from_pretrained(
‘kandinsky-community/kandinsky-2-2-prior’,
image_encoder=image_encoder,
torch_dtype=torch.float16
).to(DEVICE)
negative_prior_prompt =‘lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature’
negative_emb = prior(
prompt=negative_prior_prompt,
num_inference_steps=25,
num_images_per_prompt=1
)
transform = transforms.ToTensor()
Initialize the random vector (image embedding)
embedding = torch.randn(1, image_encoder.config.projection_dim, requires_grad=True, device=DEVICE)
Define an optimizer
optimizer = Adam([embedding], lr=0.01)
Load the target image and preprocess it
image_path = ‘/content/drive/MyDrive/datasets/ILSVRC/Data/CLS-LOC/val_with_labels/n01443537/ILSVRC2012_val_00000236.JPEG’
init_image = Image.open(image_path).convert(‘RGB’)
init_image = init_image.resize((512, 512))
target_image = transform(init_image)
target_image = target_image.to(DEVICE).half()
Define the number of iterations
num_iterations = 1000
for _ in range(num_iterations):
# Generate an image from the embedding
generated_image = decoder(
image_embeds=embedding,
negative_image_embeds=negative_emb.image_embeds,
num_inference_steps=75,
height=512,
width=512,
output_type = ‘np’,
return_dict = False)
generated_image_tensor = torch.from_numpy(generated_image[0])
generated_image_tensor = generated_image_tensor.squeeze(0).permute(2, 0, 1).to(DEVICE)
# Calculate the loss
loss = F.mse_loss(generated_image_tensor, target_image)
# Backpropagate the loss and update the embedding
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Optionally, print the loss or save intermediate results
print(f"Iteration {_}, Loss: {loss.item()}")
The error i get:
RuntimeError Traceback (most recent call last)
in <cell line: 23>()
40 # Backpropagate the loss and update the embedding
41 optimizer.zero_grad()
—> 42 loss.backward()
43 optimizer.step()
44
2 frames
/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
523 inputs=inputs,
524 )
→ 525 torch.autograd.backward(
526 self, gradient, retain_graph, create_graph, inputs=inputs
527 )
/usr/local/lib/python3.10/dist-packages/torch/autograd/init.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
265 # some Python versions print out the first line of a multi-line function
266 # calls in the traceback and some print out the last line
→ 267 engine_run_backward(
268 tensors,
269 grad_tensors,
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
742 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
743 try:
→ 744 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
745 t_outputs, *args, **kwargs
746 ) # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn