Hi,
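You can use DP (DataParallel) or DDP (DistributedDataParallel) to run a model on multiple GPUs. Note that both put a full copy of the model on every GPU and split the input batch between them, so they speed up training but do not help on their own if the model is too large to fit on a single GPU. In that case you need model parallelism, i.e. placing different layers on different GPUs; the DDP tutorial linked below has a section on combining DDP with model parallelism.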
https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
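
As a rough illustration of the DataParallel option, here is a minimal sketch. The toy network and the helper names build_model and wrap_for_multi_gpu are made up for this example and are not part of PyTorch; replace them with your own model. It falls back to a single GPU or the CPU when fewer than two GPUs are visible.

```python
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical toy network standing in for your real model.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))


def wrap_for_multi_gpu(model: nn.Module) -> nn.Module:
    # Wrap in DataParallel only when more than one GPU is visible.
    # Each GPU then holds a full replica and receives a slice of the batch.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)


model = wrap_for_multi_gpu(build_model())
device = next(model.parameters()).device

# The batch dimension (8 here) is what gets split across the GPUs.
batch = torch.randn(8, 16, device=device)
output = model(batch)
print(output.shape)
```

DDP needs more setup (one process per GPU plus a process group), so for that it is probably easier to follow the tutorial above than a short snippet.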
Regards.