Does Accelerate still kick in for validation (forward passes) too? Or is it probably just my training snippet?
# Performing validation and logging out images
if epoch % 2 == 0:  # i % 100 == 0
    model.eval()
    # --------------VALIDATION------------------
    for i, (img, label) in enumerate(val_loader):
        img.to(device)
        model.to(device)
        with torch.no_grad():
            out, latent_loss = model(img)
            val_recon_loss = criterion(out, img)
            val_latent_loss = latent_loss.mean()
            val_loss = recon_loss + latent_loss_weight * latent_loss
            val_mse_sum += recon_loss.item() * img.shape[0]
            val_mse_n += img.shape[0]
        wandb.log({"epoch": epoch + 1, "val_mse": val_recon_loss.item(),
                   "val_latent_loss": val_latent_loss.item(),
                   "val_avg_mse": (val_mse_sum / val_mse_n),
                   "lr": lr})
        loader.set_description(
            (
                f'epoch: {epoch + 1}; val_mse: {val_recon_loss.item():.5f};'
                f'val_latent: {val_latent_loss.item():.3f}; val_avg_mse: {val_mse_sum / val_mse_n:.5f};'
                f'lr: {lr:.5f}'
            ))
    model.train()
Clearly, during my validation stage it's using the CPU for some reason…
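For reference, here is a minimal sketch of how the device handling in a validation loop like the one above could look in plain PyTorch. It is not a diagnosis of the exact problem. The tiny `nn.Linear` model and the list used as `val_loader` are hypothetical stand-ins, not the original VQ-VAE setup. The one point it illustrates is that `Tensor.to(device)` returns a new tensor and leaves the original where it was, so a bare `img.to(device)` has no effect unless the result is assigned back, whereas `nn.Module.to(device)` moves the module's parameters in place.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the real model, criterion, and val_loader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 4).to(device)  # nn.Module.to() moves parameters in place
criterion = nn.MSELoss()
val_loader = [(torch.randn(8, 4), torch.zeros(8)) for _ in range(3)]

val_mse_sum, val_mse_n = 0.0, 0

model.eval()
with torch.no_grad():
    for img, label in val_loader:
        # Tensor.to() returns a new tensor, so the result must be reassigned.
        img = img.to(device)
        out = model(img)
        # Use the validation loss here, not a variable left over from training.
        val_recon_loss = criterion(out, img)
        val_mse_sum += val_recon_loss.item() * img.shape[0]
        val_mse_n += img.shape[0]
model.train()

val_avg_mse = val_mse_sum / val_mse_n
print(f"device: {next(model.parameters()).device}; val_avg_mse: {val_avg_mse:.5f}")
```

If the training objects were passed through `accelerator.prepare(...)`, a `val_loader` that was not prepared the same way will not have its batches moved to the device automatically, so that may also be worth checking.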