Trouble using Torchmetrics with Accelerate in Stable Diffusion fine-tuning script

Hello,

I’m trying to add an FID metric to the log_validation function of the Stable Diffusion training script under diffusers/examples/text_to_image. I’m following the Evaluating Diffusion Models guide and using the torchmetrics FID implementation. However, even with ~8 GB of GPU memory free, the call to fid.compute() hangs indefinitely for no obvious reason. All validation code runs only on the main process (guarded by accelerator.is_main_process), and I’ve tried manually moving both the inputs and the fid object to the appropriate CUDA device.

The validation code works fine when I run it standalone, so I suspect it’s something to do with Accelerate, even though I’m manually placing the tensors and the fid model on the device. Any ideas?

Update: when I force close the training, I see: “[Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=224, OpType=_ALLGATHER_BASE, NumelIn=2, NumelOut=8, Timeout(ms)=1800000) ran for 1800596 milliseconds before timing out.”

So something is causing that NCCL call to hang…


Hi, did you manage to solve this? I’m running into the same issue. Thanks!


No luck. As a workaround I ended up moving the FID validation code to a standalone process. While it’s less convenient, this decoupling has the benefit of reducing GPU memory and compute load during your Accelerate-based training (assuming you have another GPU you can run the FID calculation on).


TL;DR: pass sync_on_compute=False when instantiating the metric class, e.g.
FrechetInceptionDistance(feature=self.cfg.inception_feature, normalize=True, sync_on_compute=False).to(device)

I had the same issue and managed to solve it by passing sync_on_compute=False to the class when I instantiated it. FID is meant to work in distributed mode, so even if you only call compute() on your main rank (rank 0), the metric knows it’s in a multi-GPU setting and tries to sync state from all GPUs; rank 0 then blocks in a collective (the _ALLGATHER_BASE in your timeout message) that the other ranks never enter. This applies to all metrics that inherit from the Metric class in torchmetrics.metric, not just FID. 🙂
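To see the deadlock mechanism in isolation, here is a toy sketch (not the real torchmetrics API; the class and its behavior are invented purely for illustration): a metric whose compute() performs a collective operation only returns once every rank has entered it, so if just the main process calls compute(), that process waits forever. Disabling the sync makes compute() use only local state.

```python
# Toy illustration of why compute() can hang in a multi-GPU run.
# This is NOT torchmetrics code -- just a stand-in for the idea that
# a synced compute() is a collective op that needs all ranks to join.
class ToyMetric:
    def __init__(self, world_size, sync_on_compute=True):
        self.world_size = world_size
        self.sync_on_compute = sync_on_compute
        self.values = []

    def update(self, v):
        # Accumulate local (per-rank) state.
        self.values.append(v)

    def compute(self):
        if self.sync_on_compute and self.world_size > 1:
            # In a real distributed metric this would be an all-gather:
            # it blocks until *every* rank calls compute(). If only
            # rank 0 does, rank 0 hangs (modeled here as an error).
            raise TimeoutError("rank 0 waits on ranks that never call compute()")
        # With syncing disabled, only local state is used.
        return sum(self.values) / len(self.values)


m = ToyMetric(world_size=2, sync_on_compute=False)
m.update(1.0)
m.update(3.0)
print(m.compute())  # 2.0 -- local-only result, no cross-rank sync
```

With sync_on_compute=True in the same two-rank setup, the compute() call above would block instead of returning, which matches the NCCL watchdog timeout seen in the original post.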
