How can I load large models like google/mt5-xl on a GPU

Hi all

Has anyone been able to load large models like google/mt5-xl on a GPU instance?

Although the model is under 15 GB, I could not train it (with batch size = 1) on a GPU instance with an A6000 (48 GB of GPU memory).

Has anyone done this before?

Can anyone explain why it consumes more GPU memory than its actual size?

Thank you all

Hi Abu

Loading Transformer models requires significantly more GPU memory than just the model size:

  • For starters, you need at least 2x the model size in memory: once for the initial weights and once to load the checkpoint.
  • Apart from the model parameters, the gradients, optimizer states, and activations also take up memory, so the actual usage will likely be more than 4x the model size.
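To make the "more than 4x" figure concrete, here is a rough back-of-the-envelope estimate. It is a minimal sketch that assumes fp32 weights and gradients (4 bytes each per parameter) and an Adam-style optimizer keeping two fp32 moment buffers per parameter (8 bytes); these byte counts are assumptions, not something stated in the answer above, and activations are left out entirely. The helper names are hypothetical.

```python
# Rough GPU memory estimate for full fine-tuning, ignoring activations.
# Assumptions: fp32 weights and gradients (4 bytes each per parameter) and an
# Adam-style optimizer storing two fp32 moment buffers per parameter (8 bytes).

GB = 1e9  # decimal gigabytes; card capacities are only roughly comparable


def params_from_checkpoint(size_gb: float, bytes_per_param: int = 4) -> float:
    """Approximate parameter count from an fp32 checkpoint size."""
    return size_gb * GB / bytes_per_param


def training_memory_gb(
    n_params: float,
    weight_bytes: int = 4,
    grad_bytes: int = 4,
    optimizer_bytes: int = 8,
) -> dict:
    """Memory (GB) for weights, gradients and optimizer states, excluding activations."""
    weights = n_params * weight_bytes / GB
    grads = n_params * grad_bytes / GB
    optim = n_params * optimizer_bytes / GB
    return {
        "weights": weights,
        "gradients": grads,
        "optimizer_states": optim,
        "total": weights + grads + optim,
    }


n = params_from_checkpoint(15.0)  # ~15 GB checkpoint, as in the question
est = training_memory_gb(n)
for name, value in est.items():
    print(f"{name:>16}: {value:6.1f} GB")
print("fits on a 48 GB A6000:", est["total"] < 48)
```

Under these assumptions a ~15 GB checkpoint corresponds to about 3.75 billion parameters and needs roughly 60 GB (4x the model size) for weights, gradients and optimizer states alone, which already exceeds 48 GB before any activations are stored.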

To train such large models, you will need some form of model parallelism, as explained in this blog post.
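The basic idea behind naive (layer-wise) model parallelism is to split the model's layers across several GPUs so that no single device has to hold all of the weights, gradients and optimizer states. The sketch below only illustrates that idea: `partition_layers` is a hypothetical helper, it assumes each layer's training memory is known in advance, and the layer sizes and budget are made-up numbers. It is not the partitioning algorithm used by SageMaker or any other library.

```python
from typing import List


def partition_layers(layer_gb: List[float], budget_gb: float, n_devices: int) -> List[List[int]]:
    """Greedily place consecutive layers on devices without exceeding budget_gb per device.

    Returns, for each device, the indices of the layers assigned to it.
    Raises ValueError if the layers do not fit on n_devices devices.
    """
    devices: List[List[int]] = [[]]
    used = 0.0
    for i, size in enumerate(layer_gb):
        if size > budget_gb:
            raise ValueError(f"layer {i} ({size} GB) exceeds the per-device budget")
        if used + size > budget_gb:
            # Current device is full: start filling the next one.
            devices.append([])
            used = 0.0
        devices[-1].append(i)
        used += size
    if len(devices) > n_devices:
        raise ValueError(f"needs {len(devices)} devices, only {n_devices} available")
    return devices


# Hypothetical example: 24 blocks of 2.5 GB each, GPUs with a 40 GB usable budget.
plan = partition_layers([2.5] * 24, budget_gb=40.0, n_devices=4)
for dev, layers in enumerate(plan):
    print(f"cuda:{dev} -> layers {layers[0]}..{layers[-1]}")
```

With these numbers the first GPU holds layers 0..15 and the second holds layers 16..23. Real frameworks also have to move activations between devices and usually pipeline micro-batches so that GPUs are not left idle while waiting on each other.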

You can find an example of how to train GPT-J (~24 GB) here: amazon-sagemaker-examples/train_gptj_smp_notebook.ipynb at main · aws/amazon-sagemaker-examples · GitHub
(this example uses Amazon SageMaker to distribute the model over several GPUs).

Hope that helps.


Thanks, @marshmellow77, for your detailed answer.