I currently load T5 with:
self.t5_enc = T5EncoderModel.from_pretrained(T5_MODEL).eval().to(self.device)
But I’m not sure whether this uses optimizations such as torch.nn.utils.skip_init, or loads from an FSDP sharded checkpoint for maximum loading speed.
How can I make sure these optimizations are enabled?