How to train the input noise while the model is frozen?

I tried to freeze the model and train only the input noise, but it failed. Here is a simple example that uses the MSE loss between the output image and the labels to update the noise:

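import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoPipelineForText2Image

model_id = "stabilityai/stable-diffusion-2-base"
model = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32).to("cuda")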
model.unet.eval()
model.vae.eval()

for param in model.unet.parameters():
    param.requires_grad = False

for param in model.vae.parameters():
    param.requires_grad = False

height = model.unet.config.sample_size * model.vae_scale_factor
width = model.unet.config.sample_size * model.vae_scale_factor
cur_noise = torch.randn([1, model.unet.config.in_channels, height // model.vae_scale_factor, width // model.vae_scale_factor]).cuda()

cur_noise = nn.Parameter(cur_noise, requires_grad=True)
optimizer = torch.optim.Adam([cur_noise], lr=args.lr)

criterion = torch.nn.MSELoss()

image = model('dog', latents=cur_noise, output_type='pt', generator=torch.manual_seed(0)).images[0]
image_ = torch.unsqueeze(image, dim=0)

image_224 = F.interpolate(image_, size=(224,224), mode='bilinear', align_corners=False)
labels = torch.zeros_like(image_224).cuda()

loss = criterion(image_224, labels)
print(loss.grad_fn)  # None

optimizer.zero_grad()
loss.backward()
optimizer.step()

The output of loss.grad_fn is None, and loss.backward() raises the following error:

Traceback (most recent call last):
  File "D:\Exp\AIGC\DOV\mytest.py", line 118, in <module>
    loss.backward()
  File "D:\anaconda3\envs\py39-torch111\lib\site-packages\torch\_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "D:\anaconda3\envs\py39-torch111\lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How can I solve this problem?

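Freezing the parameters is probably not what breaks backpropagation here. With requires_grad = False on the weights, gradients can still flow through the model to cur_noise. A loss.grad_fn of None means no graph was recorded at all, which is what torch.no_grad() does, and the diffusers pipeline __call__ is decorated with @torch.no_grad(). So calling model(...) drops the graph no matter how the parameters are set, and you would need to run the denoising loop and the VAE decode yourself with gradients enabled. These lines are not the cause and can stay if only the noise should be updated: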

for param in model.unet.parameters():
    param.requires_grad = False

for param in model.vae.parameters():
    param.requires_grad = False
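
It may help to check which setting actually removes the graph. The sketch below uses a toy nn.Linear as a stand-in for the UNet/VAE (an assumption for illustration, not the diffusers pipeline) and compares a forward pass with frozen weights against the same pass wrapped in torch.no_grad(). In plain PyTorch, only the second one leaves loss.grad_fn as None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the frozen UNet/VAE (hypothetical, not the diffusers model).
frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad = False

# The only trainable tensor is the input noise.
noise = nn.Parameter(torch.randn(1, 4))
optimizer = torch.optim.Adam([noise], lr=0.1)

# Case 1: frozen weights, gradients enabled. The graph still reaches `noise`.
loss = frozen(noise).pow(2).mean()
has_graph_frozen = loss.grad_fn is not None

before = noise.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()
noise_changed = not torch.equal(before, noise.detach())

# Case 2: the same forward pass under torch.no_grad(). No graph is recorded,
# so calling backward() on this loss would raise the error from the question.
with torch.no_grad():
    loss_no_grad = frozen(noise).pow(2).mean()
has_graph_no_grad = loss_no_grad.grad_fn is not None

print(has_graph_frozen, noise_changed, has_graph_no_grad)
```

If the first case behaves the same way with the real model, the frozen parameters are fine and the place to look is wherever the forward pass disables gradient tracking.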

You might need gradient checkpointing if you run into memory issues, since backpropagating through every denoising step keeps a lot of activations around.
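
To illustrate the idea, here is a minimal sketch of gradient checkpointing with torch.utils.checkpoint on a toy module that stands in for one denoising step (the module, sizes, and step count are assumptions, not the diffusers UNet). Checkpointing recomputes activations during the backward pass instead of storing them, and the gradient with respect to the noise should match the plain version. diffusers models also expose enable_gradient_checkpointing(), though whether it takes effect in eval mode may depend on the library version.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Toy stand-in for one denoising step (hypothetical, not the diffusers UNet).
step = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 8))
for p in step.parameters():
    p.requires_grad = False  # frozen weights, as in the question

# The trainable input noise.
noise = nn.Parameter(torch.randn(1, 8))


def run(num_steps, use_checkpoint):
    """Apply the step repeatedly and return a scalar loss."""
    x = noise
    for _ in range(num_steps):
        if use_checkpoint:
            # Activations inside `step` are recomputed during backward.
            x = checkpoint(step, x, use_reentrant=False)
        else:
            x = step(x)
    return x.pow(2).mean()


# Reference gradient without checkpointing.
noise.grad = None
run(4, use_checkpoint=False).backward()
grad_plain = noise.grad.clone()

# Gradient with checkpointing: same values, lower peak activation memory.
noise.grad = None
run(4, use_checkpoint=True).backward()
grad_ckpt = noise.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt, atol=1e-6))
```

The trade-off is extra compute: each checkpointed step runs its forward pass twice.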
