Hi,
Would someone know how to help with this issue? Thanks!
If in the following code snipped, I change the device from CPU to Cuda it throws the following error:
import torch
from diffusers import AutoencoderKL
device = "cuda"
vae_2 = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32,)
vae_2.train()
classes = 2
vae_2.to(device)
latents = torch.randn(1, 4, 64, 64).to(torch.float32).to(device)
decoded_mask = vae_2.decode(latents, return_dict=False)[0]
decoded_mask.requires_grad_(True)
target = torch.randn(decoded_mask.shape).to(device)
loss = torch.nn.MSELoss()(decoded_mask, target)
loss.backward()
Error:
โ โ
โ 25 โ
โ 26 print(loss) โ
โ 27 โ
โ โฑ 28 loss.backward() โ
โ 29 โ
โ 30 print(vae_2.decoder.conv_out.weight.grad) โ
โ 31 print(vae_2.decoder.conv_out.bias.grad) โ
โ โ
โ /home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/to โ
โ rch/_tensor.py:487 in backward โ
โ โ
โ 484 โ โ โ โ create_graph=create_graph, โ
โ 485 โ โ โ โ inputs=inputs, โ
โ 486 โ โ โ ) โ
โ โฑ 487 โ โ torch.autograd.backward( โ
โ 488 โ โ โ self, gradient, retain_graph, create_graph, inputs=inputs โ
โ 489 โ โ ) โ
โ 490 โ
โ โ
โ /home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/to โ
โ rch/autograd/__init__.py:200 in backward โ
โ โ
โ 197 โ # The reason we repeat same the comment below is that โ
โ 198 โ # some Python versions print out the first line of a multi-line fu โ
โ 199 โ # calls in the traceback and some print out the last line โ
โ โฑ 200 โ Variable._execution_engine.run_backward( # Calls into the C++ eng โ
โ 201 โ โ tensors, grad_tensors_, retain_graph, create_graph, inputs, โ
โ 202 โ โ allow_unreachable=True, accumulate_grad=True) # Calls into th โ
โ 203 โ
โฐโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฏ
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so
the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.