How can I speed up T5 loading?

I currently load T5 with:

self.t5_enc = T5EncoderModel.from_pretrained(T5_MODEL).eval().to(self.device)

However, I’m not sure whether this uses optimizations such as torch.nn.utils.skip_init, or loading from an FSDP sharded checkpoint, for maximum loading speed.

How can I make sure these optimizations are enabled?
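
For context, here is a minimal sketch of what I mean by skip_init, using a plain nn.Linear rather than T5. The helper names build_linear and load_into_uninitialized are made up for illustration, and the snippet says nothing about what from_pretrained does internally. It only shows the idea of skipping the random weight initialization when the weights are about to be overwritten by a checkpoint anyway.

```python
import torch
from torch import nn


def build_linear(skip: bool, in_features: int, out_features: int) -> nn.Linear:
    """Build a Linear layer, optionally skipping random weight initialization.

    Hypothetical helper for illustration only.
    """
    if skip:
        # skip_init allocates the parameters without running reset_parameters(),
        # so their contents are uninitialized memory until weights are loaded.
        return torch.nn.utils.skip_init(nn.Linear, in_features, out_features)
    return nn.Linear(in_features, out_features)


def load_into_uninitialized(state_dict: dict) -> nn.Linear:
    """Create an uninitialized layer and fill it from a state dict.

    Hypothetical helper for illustration only.
    """
    out_features, in_features = state_dict["weight"].shape
    layer = build_linear(skip=True, in_features=in_features, out_features=out_features)
    # Every parameter is overwritten here, so the skipped init is never observed.
    layer.load_state_dict(state_dict)
    return layer


# Stand-in for a checkpoint: the state dict of a normally initialized layer.
reference = nn.Linear(64, 32)
restored = load_into_uninitialized(reference.state_dict())
```

The point is that the initialization work is wasted whenever a full checkpoint is loaded right afterwards. What I can't tell is whether the T5 loading path already avoids it.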

I’ve never used FSDP, so I don’t know…