Gradient Issue after using torch.no_grad

Hi, I am training a diffusion model and running into a weird issue where gradient tracking stops working.

If I run this code segment

noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]

then noise_pred.requires_grad is True. But if I run this instead

with torch.no_grad():
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
        return_dict=False,
    )[0]

# second call, outside the no_grad context
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]
print(noise_pred.requires_grad)

the result is False. It seems that calling the UNet once under the torch.no_grad context somehow turns off gradient tracking for later calls as well. I checked requires_grad for the UNet's parameters and it was True.
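For comparison, this is what I would expect from no_grad in isolation (a minimal sketch with a toy nn.Linear standing in for the UNet, since the full pipeline is hard to reproduce here):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)   # toy stand-in for the UNet
x = torch.randn(1, 4)

with torch.no_grad():
    y = model(x)
print(y.requires_grad)            # False: tensor created inside no_grad

y = model(x)                      # fresh forward pass after the block
print(y.requires_grad)            # True: grad mode is restored on exit
print(torch.is_grad_enabled())    # True: no_grad does not leak globally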

Thanks in advance


with torch.no_grad():

It seems that tensors originating from inside this with block remain affected by no_grad even after the block ends: they are created with requires_grad=False and stay that way. In addition, there appear to be cases where the effect lingers due to caching…
https://stackoverflow.com/questions/63785319/pytorch-torch-no-grad-versus-requires-grad-false
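As a rough illustration (a minimal sketch with a toy nn.Linear, not your actual pipeline): a tensor created under no_grad stays detached when reused later, so gradients cannot flow back through it, even though a forward pass outside the block tracks gradients again:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

with torch.no_grad():
    cached = model(torch.randn(1, 4))   # created under no_grad

print(cached.requires_grad)   # False: stays detached after the block ends

out = model(cached)           # reused outside the block
print(out.requires_grad)      # True (parameters require grad), but no
                              # gradient will flow back through `cached`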
