I think this might be related to a NaN loss going into the FP16 scaler? ref
I'm not sure why the scaler wouldn't catch that and skip the batch.
edit: I caught a few NaN batches going into the self.scaler.scale(loss).backward()
step, but I've since also seen the error triggered by batches with normal loss values.
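For what it's worth, here is a minimal sketch of the kind of guard I mean (names like model, optimizer, criterion, and loader are placeholders for whatever the real loop uses). As far as I understand, GradScaler only checks the gradients inside scaler.step() for infs/NaNs and skips the optimizer step then; it never looks at the loss value itself, so a NaN loss has to be caught manually before backward():

import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # The scaler never inspects the loss, only the gradients during
    # scaler.step(), so skip the batch here to avoid backpropagating NaNs.
    if not torch.isfinite(loss):
        print("non-finite loss, skipping batch:", loss.item())
        continue

    scaler.scale(loss).backward()
    scaler.step(optimizer)   # skipped internally if grads contain inf/NaN
    scaler.update()

This catches the NaN-loss case, but it obviously doesn't explain the failures I'm seeing on batches where the loss looks normal.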