Hi,
Would someone know how to help with this issue? Thanks!
In the following code snippet, if I change the device from CPU to CUDA, the call to loss.backward() throws the error shown below:
import torch
from diffusers import AutoencoderKL
device = "cuda"
vae_2 = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
vae_2.train()
classes = 2
vae_2.to(device)
latents = torch.randn(1, 4, 64, 64).to(torch.float32).to(device)
decoded_mask = vae_2.decode(latents, return_dict=False)[0]
decoded_mask.requires_grad_(True)
target = torch.randn(decoded_mask.shape).to(device)
loss = torch.nn.MSELoss()(decoded_mask, target)
loss.backward()
Error:
  25
  26 print(loss)
  27
❱ 28 loss.backward()
  29
  30 print(vae_2.decoder.conv_out.weight.grad)
  31 print(vae_2.decoder.conv_out.bias.grad)

/home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/torch/_tensor.py:487 in backward

  484                 create_graph=create_graph,
  485                 inputs=inputs,
  486             )
❱ 487         torch.autograd.backward(
  488             self, gradient, retain_graph, create_graph, inputs=inputs
  489         )
  490

/home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/torch/autograd/__init__.py:200 in backward

  197     # The reason we repeat same the comment below is that
  198     # some Python versions print out the first line of a multi-line fu
  199     # calls in the traceback and some print out the last line
❱ 200     Variable._execution_engine.run_backward(  # Calls into the C++ eng
  201         tensors, grad_tensors_, retain_graph, create_graph, inputs,
  202         allow_unreachable=True, accumulate_grad=True)  # Calls into th
  203

RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so
the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
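
In case it helps anyone trying to reproduce this: the message above suggests CUDA_LAUNCH_BLOCKING=1, so here is a minimal sketch of how the script could be rerun with synchronous kernel launches while printing the version details that usually matter for this kind of error. The helper name describe_environment is hypothetical (it is not part of torch or diffusers), and the sketch assumes the variable is set before torch is imported so that it takes effect before the first CUDA call.

```python
import os

# Assumption: set this before importing torch (i.e. before any CUDA call),
# so kernel launches run synchronously and the traceback points at the
# operation that actually failed.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def describe_environment():
    """Hypothetical helper: gather version info to attach to the report."""
    info = {"CUDA_LAUNCH_BLOCKING": os.environ.get("CUDA_LAUNCH_BLOCKING")}
    try:
        import torch
    except ImportError:
        # torch is not installed here; report that instead of failing.
        info["torch"] = None
        return info
    info["torch"] = torch.__version__
    info["torch_cuda_build"] = torch.version.cuda  # None on CPU-only builds
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


print(describe_environment())
```

Placing these lines at the top of the snippet above, before its own import torch, should make the rerun report the failing call directly rather than at a later API call.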