I think that method is correct, but there seem to be reports of tensor mismatch errors when training with FSDP (Fully Sharded Data Parallel).