I am trying to use the encoder part of a VAE, but calling backward() on it fails on the GPU. Surprisingly, it works on the CPU. Here are my code and the error message:
def vae(path="/mnt/traffic/home/dongziping/sdxl-vae"):
    """Load a pretrained AutoencoderKL and return only its encoder sub-module.

    Args:
        path: Local directory (or model id) passed to
            ``AutoencoderKL.from_pretrained``. Defaults to the original
            hard-coded SDXL-VAE checkpoint location, so existing callers that
            pass no argument behave exactly as before.

    Returns:
        The ``encoder`` attribute of the loaded ``AutoencoderKL``. Callers feed
        it an image tensor; the script below uses a (1, 3, h, w) input.
    """
    model = AutoencoderKL.from_pretrained(path)
    # Only the encoder is kept; the rest of the autoencoder is discarded here.
    return model.encoder
if __name__ == "__main__":
    import contextlib

    import torch

    # Use the GPU when one is available; otherwise fall back to the CPU, where
    # the backward pass is reported to work.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = vae().to(device)
    # eval() only switches layers such as dropout/normalization to inference
    # behaviour; it does NOT disable autograd, so gradients still reach the input
    # graph and the encoder parameters.
    model.eval()

    image = torch.randn((1, 3, 28, 28)).to(device)

    # NOTE(review): the GPU-only backward failure is presumably raised by a fused
    # scaled_dot_product_attention kernel (flash / memory-efficient) used in the
    # encoder's attention block, whose backward pass may not support that head
    # size on every GPU. Restricting SDPA to the plain "math" backend is a
    # workaround to try -- confirm against the actual error message. On torch
    # builds without ``torch.backends.cuda.sdp_kernel`` this is a no-op.
    sdp_kernel = getattr(torch.backends.cuda, "sdp_kernel", None)
    if sdp_kernel is not None:
        attention_ctx = sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        )
    else:
        attention_ctx = contextlib.nullcontext()

    with attention_ctx:
        # NOTE(review): with diffusers' AutoencoderKL the raw encoder output is
        # believed to be the posterior moments (mean and log-variance stacked on
        # the channel axis), not a sampled latent -- confirm. Use
        # ``AutoencoderKL.encode(x).latent_dist`` if a real latent is needed.
        latent = model(image)
        loss = torch.sum(latent)
        loss.backward()

    # .item() prints a plain number instead of the tensor repr with grad_fn/device.
    print("loss:", loss.item())