2B Model Fills Up GPU Memory on 4xA100s

With 4xA100s, I am training a 2B model (Gemma-2-2b-it) for SequenceClassification on a custom dataset. My code is fairly boilerplate, but what I don't understand is why training fills up the GPU memory. As I see it, a model like Gemma-2-2b has roughly 2B parameters. Assuming full-precision (float32) training, the memory needed just to hold the weights is 2B parameters * (32/8 = 4 bytes/parameter) = 8 GB. During training there are also gradients, optimizer states, and optimizer intermediates that add to the total, so the usual rule of thumb for training a 2B model at full precision is about 5 * 8 GB = 40 GB of GPU memory. If my understanding is correct, how can training take up nearly ~250 GB of GPU memory? FYI, I use accelerate with DeepSpeed ZeRO-3.
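For reference, this is the back-of-envelope arithmetic I'm doing (the 5x multiplier for gradients plus Adam states is just the usual rule of thumb, and activations are not counted):

n_params = 2e9                  # ~2B parameters for Gemma-2-2b
bytes_per_param = 4             # float32

weights_gb = n_params * bytes_per_param / 1e9   # ~8 GB just for the weights
train_gb = 5 * weights_gb                       # ~40 GB with gradients + Adam states (rule of thumb)

# ZeRO-3 should additionally shard parameters/gradients/optimizer states across
# the 4 GPUs, so I'd expect per-GPU usage well below 40 GB, not ~55-80 GB each.
per_gpu_gb = train_gb / 4
print(weights_gb, train_gb, per_gpu_gb)         # 8.0 40.0 10.0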

with open("all.json", 'r') as f:
    all_data = json.load(f)
dataset = load_dataset('json', data_files={'train': 'debug.json', 'test': 'test.json'})
classes = list(set(d['label'] for d in all_data))
id2label = {i: str(c) for i, c in enumerate(classes)}
label2id = {str(c): i for i, c in enumerate(classes)}

model_path = "gemma2"

tokenizer = AutoTokenizer.from_pretrained(model_path)

def preprocess_function(example):
    # One-hot encode the label for multi-label classification (BCE loss expects floats).
    labels = [1. if classes[i] == str(example["label"]) else 0. for i in range(len(classes))]

    example = tokenizer(example["text"], truncation=True)
    example['label'] = labels
    return example

tokenized_dataset = dataset.map(preprocess_function)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(classes),
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification",
)

training_args = TrainingArguments(
    output_dir="model",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000001:00:00.0 Off |                    0 |
| N/A   42C    P0             79W /  300W |   74815MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  |   00000002:00:00.0 Off |                    0 |
| N/A   38C    P0             69W /  300W |   81031MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          On  |   00000003:00:00.0 Off |                    0 |
| N/A   39C    P0             81W /  300W |   54311MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          On  |   00000004:00:00.0 Off |                    0 |
| N/A   40C    P0             81W /  300W |   54311MiB /  81920MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    561923      C   /home/azureuser/miniconda/bin/python        74806MiB |
|    1   N/A  N/A    561924      C   /home/azureuser/miniconda/bin/python        81022MiB |
|    2   N/A  N/A    561925      C   /home/azureuser/miniconda/bin/python        54302MiB |
|    3   N/A  N/A    561926      C   /home/azureuser/miniconda/bin/python        54302MiB |
+-----------------------------------------------------------------------------------------+
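For what it's worth, here is a quick check I could add during training to see how much of those nvidia-smi numbers are live tensors versus PyTorch's allocator cache (standard torch.cuda API, nothing DeepSpeed-specific):

import torch

dev = torch.cuda.current_device()
allocated_gb = torch.cuda.memory_allocated(dev) / 1e9   # memory held by live tensors
reserved_gb = torch.cuda.memory_reserved(dev) / 1e9     # memory held by the caching allocator (roughly what nvidia-smi sees)
print(f"device {dev}: allocated={allocated_gb:.1f} GB, reserved={reserved_gb:.1f} GB")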

Could it be that DeepSpeed has a bug…?