TL;DR: Why doesn't Accelerate/FSDP seem to be doing much of anything to reduce memory in the following?
I'm trying to get some hands-on experience and learn how to run large models across multiple nodes and/or GPUs. I'm starting with Trainer/accelerate/FSDP2 and planning to work up from there, but I think I'm missing something.
python 3.12.9
torch 2.7.0
transformers 4.52.4
accelerate 1.7.0
My "toy" program to train an "empty" model:
from datasets import Dataset, DatasetDict
from datasets.utils.logging import disable_progress_bar
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import DefaultDataCollator, DataCollatorForLanguageModeling
from transformers import TrainingArguments, Trainer
import torch
import os

model_dir = 'NousResearch/Llama-3.2-1B'
TRACE = False
N = 2048
context_length = 64
batch_size = 64

def load_datasets():
    # synthetic train/eval data: one short sentence per example
    train_data_list = [
        {"text": "The quick brown fox jumped over the lazy dog's back t{:06d}".format(i)} for i in range(4*N)
    ]
    eval_data_list = [
        {"text": "The quick brown fox jumped over the lazy dog's back e{:06d}".format(i)} for i in range(N)
    ]
    datasets = DatasetDict(  # create a DatasetDict with train and eval splits
        {'train': Dataset.from_list(train_data_list),
         'eval': Dataset.from_list(eval_data_list)}
    )
    return datasets

def load_tokenizer(model_dir):
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return tokenizer

def load_model(model_dir):
    # get just the config from the pretrained directory; weights are randomly initialized
    config = AutoConfig.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_config(config)
    return model

def mytrain(model_dir):
    def tokenize(dataset):
        return tokenizer(dataset['text'], padding='max_length', max_length=context_length, return_length=True)
    ##
    raw_datasets = load_datasets()
    if TRACE: print("dataset\n", raw_datasets)
    ##
    tokenizer = load_tokenizer(model_dir)
    if TRACE: print("tokenizer\n", tokenizer)
    ##
    tokenizer.pad_token = tokenizer.eos_token
    tokenized_datasets = raw_datasets.map(
        tokenize, batched=True, remove_columns=raw_datasets["train"].column_names)
    if TRACE: print("tokenized_datasets\n", tokenized_datasets)
    ##
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    if TRACE:
        example_collated = data_collator([tokenized_datasets["train"][i] for i in range(3)])
        print("example_collated\n", example_collated)
    ##
    training_args = TrainingArguments(  # do this before model load for FSDP?
        output_dir="outputs/",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=10,
        logging_strategy="epoch",
        eval_strategy="epoch",
        save_strategy="no",
        push_to_hub=False,
        disable_tqdm=True,
        deepspeed=None,
    )
    ##
    model = load_model(model_dir)  # do this after TrainingArguments, which sets up some distributed state?
    if TRACE: print("model\n", model)
    ##
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["eval"],
        processing_class=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()

if __name__ == "__main__":
    disable_progress_bar()
    mytrain(
        model_dir=model_dir
    )
    if torch.distributed.is_initialized():  # only tear down a process group if one was created
        torch.distributed.destroy_process_group()
I first run my test program as plain python/pytorch: a single GPU, without accelerate.
[gpu2:training] CUDA_VISIBLE_DEVICES=0 python 05_acctest.py
{'loss': 0.8924, 'grad_norm': 0.8125, 'learning_rate': 4.50390625e-05, 'epoch': 1.0}
{'eval_loss': 2.5442957878112793, 'eval_runtime': 2.4496, 'eval_samples_per_second': 836.064, 'eval_steps_per_second': 13.063, 'epoch': 1.0}
{'loss': 0.6293, 'grad_norm': 0.65234375, 'learning_rate': 4.00390625e-05, 'epoch': 2.0}
{'eval_loss': 2.6600184440612793, 'eval_runtime': 2.4495, 'eval_samples_per_second': 836.094, 'eval_steps_per_second': 13.064, 'epoch': 2.0}
.
.
.
{'loss': 0.6061, 'grad_norm': 0.4921875, 'learning_rate': 3.90625e-08, 'epoch': 10.0}
{'eval_loss': 2.8240463733673096, 'eval_runtime': 2.4496, 'eval_samples_per_second': 836.055, 'eval_steps_per_second': 13.063, 'epoch': 10.0}
{'train_runtime': 333.183, 'train_samples_per_second': 245.871, 'train_steps_per_second': 3.842, 'train_loss': 0.6405227959156037, 'epoch': 10.0}
While it's running, I use nvidia-smi to look at the memory used:
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 21181 C python 21372MiB |
+-----------------------------------------------------------------------------------------+
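As a cross-check on that number (nvidia-smi reports the whole process: CUDA context plus everything the caching allocator has reserved, not just live tensors), something like the following could be dropped in right after trainer.train() in mytrain(). This is just an untested sketch and was not part of the runs shown here:

# Hypothetical cross-check (not part of the runs above): PyTorch's own allocator
# counters versus what nvidia-smi shows for the whole process.
gib = 2**30
print(f"peak allocated: {torch.cuda.max_memory_allocated() / gib:.2f} GiB")  # live tensors, high-water mark
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / gib:.2f} GiB")   # allocator pool, closer to nvidia-smi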
The ~21 GB reported by nvidia-smi is at least in the ballpark of what accelerate estimates:
[gpu2:training] accelerate estimate-memory NousResearch/Llama-3.2-1B
Loading pretrained config for `NousResearch/Llama-3.2-1B` from `transformers`...
┌──────────────────────────────────────────────────────┐
│ Memory Usage for loading `NousResearch/Llama-3.2-1B` │
├───────┬─────────────┬──────────┬─────────────────────┤
│ dtype │Largest Layer│Total Size│ Training using Adam │
├───────┼─────────────┼──────────┼─────────────────────┤
│float32│  1002.0 MB  │  4.6 GB  │      18.42 GB       │
│float16│  501.0 MB   │  2.3 GB  │       9.21 GB       │
│  int8 │  250.5 MB   │ 1.15 GB  │         N/A         │
│  int4 │  125.25 MB  │589.28 MB │         N/A         │
└───────┴─────────────┴──────────┴─────────────────────┘
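(For what it's worth, the 18.42 GB figure looks like roughly 4 × the fp32 parameter size: weights + gradients + two Adam moment buffers ≈ 4 × 4.6 GB ≈ 18.4 GB. Activations, the CUDA context, and allocator overhead come on top of that, which presumably accounts for the ~21 GB nvidia-smi shows.)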
Next I use "accelerate config" to generate a config file for 2 GPUs using FSDP2 (mostly with default values).
[gpu2:training] cat 1n2gfsdp_defaults.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Using that file and running with accelerate…
[gpu2:training] CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file 1n2gfsdp_defaults.yaml 05_acctest.py
{'loss': 1.0797, 'grad_norm': 0.6328125, 'learning_rate': 4.5078125000000006e-05, 'epoch': 1.0}
{'eval_loss': 2.5193161964416504, 'eval_runtime': 1.376, 'eval_samples_per_second': 1488.383, 'eval_steps_per_second': 11.628, 'epoch': 1.0}
{'loss': 0.6584, 'grad_norm': 0.4609375, 'learning_rate': 4.0078125e-05, 'epoch': 2.0}
{'eval_loss': 2.5891079902648926, 'eval_runtime': 1.3771, 'eval_samples_per_second': 1487.218, 'eval_steps_per_second': 11.619, 'epoch': 2.0}
.
.
.
{'loss': 0.6096, 'grad_norm': 0.462890625, 'learning_rate': 7.8125e-08, 'epoch': 10.0}
{'eval_loss': 2.754133462905884, 'eval_runtime': 1.3776, 'eval_samples_per_second': 1486.605, 'eval_steps_per_second': 11.614, 'epoch': 10.0}
{'train_runtime': 178.9799, 'train_samples_per_second': 457.705, 'train_steps_per_second': 3.576, 'train_loss': 0.6661747217178344, 'epoch': 10.0}
… nvidia-smi memory during the computation…
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 24421 C ...AI/training-4.52.4/bin/python 21384MiB |
| 1 N/A N/A 24422 C ...AI/training-4.52.4/bin/python 21388MiB |
+-----------------------------------------------------------------------------------------+
Next, a config file with 4 GPUs…
[gpu2:training] cat 1n4gfsdp_defaults.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
… execute using accelerate…
[gpu2:training] CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 1n4gfsdp_defaults.yaml 05_acctest.py
{'loss': 1.373, 'grad_norm': 0.458984375, 'learning_rate': 4.515625e-05, 'epoch': 1.0}
{'eval_loss': 2.402463912963867, 'eval_runtime': 0.6972, 'eval_samples_per_second': 2937.372, 'eval_steps_per_second': 11.474, 'epoch': 1.0}
{'loss': 0.7474, 'grad_norm': 0.435546875, 'learning_rate': 4.0156250000000004e-05, 'epoch': 2.0}
{'eval_loss': 2.3128156661987305, 'eval_runtime': 0.6946, 'eval_samples_per_second': 2948.607, 'eval_steps_per_second': 11.518, 'epoch': 2.0}
.
.
.
{'loss': 0.6214, 'grad_norm': 0.30078125, 'learning_rate': 1.5625e-07, 'epoch': 10.0}
{'eval_loss': 2.432434320449829, 'eval_runtime': 0.694, 'eval_samples_per_second': 2950.801, 'eval_steps_per_second': 11.527, 'epoch': 10.0}
{'train_runtime': 89.6101, 'train_samples_per_second': 914.182, 'train_steps_per_second': 3.571, 'train_loss': 0.718875628709793, 'epoch': 10.0}
… nvidia-smi while executing…
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 25570 C ...AI/training-4.52.4/bin/python 20526MiB |
| 1 N/A N/A 25571 C ...AI/training-4.52.4/bin/python 20146MiB |
| 2 N/A N/A 25572 C ...AI/training-4.52.4/bin/python 20146MiB |
| 3 N/A N/A 25573 C ...AI/training-4.52.4/bin/python 20146MiB |
+-----------------------------------------------------------------------------------------+
Clearly something is happening: I'm getting a throughput benefit from adding GPUs (almost linear!). But I'm not seeing a substantial reduction in per-GPU memory usage.
- Is my config file missing something? Are there better parameters that facilitate memory savings?
- Can I somehow get accelerate to dump what it thinks itβs doing (vs. what I specified in the config file)?
- Can I somehow dump the wrapped model to see what FSDP has done? (A sketch of what I'd try for these last two is after this list.)
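For those last two questions, here is roughly what I have in mind, added right after trainer.train() in mytrain(). This is an untested sketch: I'm assuming trainer.accelerator and trainer.model_wrapped are the right attributes in this transformers version, and that the accelerate state carries an fsdp_plugin when FSDP is active.

# Hypothetical introspection (assumes trainer.accelerator, trainer.model_wrapped and
# state.fsdp_plugin exist in this transformers/accelerate combination).
if trainer.accelerator.is_main_process:
    print("accelerator state\n", trainer.accelerator.state)                          # distributed type, num processes, ...
    print("fsdp plugin\n", getattr(trainer.accelerator.state, "fsdp_plugin", None))  # resolved FSDP settings
    print("wrapped model\n", trainer.model_wrapped)                                  # module tree after FSDP wrapping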
===============================================================
I did a similar experiment with bigscience/bloom-3b, just to see if it made any difference, and things still seem strange: the plain single-GPU run uses about 75 GB, and the 4-GPU FSDP run still shows roughly 53 GB per GPU.
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 37058 C python 74748MiB |
+-----------------------------------------------------------------------------------------+
┌────────────────────────────────────────────────────┐
│   Memory Usage for loading `bigscience/bloom-3b`   │
├───────┬─────────────┬──────────┬───────────────────┤
│ dtype │Largest Layer│Total Size│Training using Adam│
├───────┼─────────────┼──────────┼───────────────────┤
│float32│   2.39 GB   │ 11.19 GB │      44.74 GB     │
│float16│   1.2 GB    │ 5.59 GB  │      22.37 GB     │
│  int8 │  612.5 MB   │  2.8 GB  │        N/A        │
│  int4 │  306.25 MB  │  1.4 GB  │        N/A        │
└───────┴─────────────┴──────────┴───────────────────┘
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 251138 C ...AI/training-4.52.4/bin/python 53922MiB |
| 1 N/A N/A 251139 C ...AI/training-4.52.4/bin/python 53538MiB |
| 2 N/A N/A 251140 C ...AI/training-4.52.4/bin/python 53538MiB |
| 3 N/A N/A 251141 C ...AI/training-4.52.4/bin/python 53538MiB |
+-----------------------------------------------------------------------------------------+