Hi everyone,
I am training LLMs on an 8×A100-80GB DGX and, because of the context length, I am bottlenecked by memory: I use a per-device batch size of 1 and gradient accumulation up to the target batch size. At this stage PEFT isn't an option; it is full-parameter training.
So far I have been using gradient checkpointing to push the context length further, at the expense of a decrease in training speed, as is to be expected.
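For reference, this is roughly how the baseline run is set up (a sketch assuming the Hugging Face Trainer; the output directory and the accumulation steps are placeholders, not my exact values):

from transformers import TrainingArguments

# Baseline: micro-batch of 1 per GPU, accumulate up to the target batch size,
# and trade compute for activation memory with gradient checkpointing.
args = TrainingArguments(
    output_dir="out",                # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=64,  # placeholder for the target batch size
    gradient_checkpointing=True,
    bf16=True,
)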
Next I have been experimenting with DeepSpeed ZeRO-2/3 in the hope of reducing memory consumption with a smaller speed penalty. In practice this has not been successful, so I am asking for advice on how to refine my DeepSpeed settings, or on anything I may have overlooked…
I am observing that GPU utilization is intermittent when using DeepSpeed, whereas with gradient checkpointing alone it stays consistently at 100%. I assume CPU offloading is a key factor in the GPU memory savings, since it shuttles data back and forth between GPU and RAM, but the ZeRO implementation should mitigate that by overlapping computation and communication so as to keep the GPUs busy… right?
As a result, I can currently train faster and with a longer context length using gradient checkpointing than with DeepSpeed ZeRO… and ZeRO-3 is slower than ZeRO-2, presumably because it adds more communication overhead (?)
I would have expected the opposite ordering, i.e. gradient checkpointing < ZeRO-2 < ZeRO-3 in terms of speed.
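One ablation I am planning, to check whether offloading is what causes the intermittent utilization, is to run the same stage with the offload entries simply removed (a sketch of just the zero_optimization block; everything else stays as in the full config below):

# Hypothetical ablation: same ZeRO stage, no CPU offload, to see whether the
# GPU stalls disappear; "offload_optimizer" and "offload_param" are omitted.
zero_no_offload = {
    "stage": zero_stage,
    "contiguous_gradients": True,
    "reduce_bucket_size": 1e7,
}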
Here is the configuration I pass to the trainer argument "deepspeed=dsc", with zero_stage in {2, 3}:
dsc = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_allow_untested_optimizer": True,
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": zero_stage,
        "contiguous_gradients": True,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_prefetch_bucket_size": 1e7,
        "stage3_param_persistence_threshold": 1e5,
        "reduce_bucket_size": 1e7,
        "sub_group_size": 1e9,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        }
    },
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": True,
        "cpu_checkpointing": True
    }
}
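And the dict is handed over to the trainer like this (a sketch assuming the Hugging Face Trainer; the model, dataset and the remaining argument values are placeholders from my setup, not shown here):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=64,  # placeholder for the target batch size
    bf16=True,
    deepspeed=dsc,                   # the config dict above (a path to a JSON file also works)
)

trainer = Trainer(
    model=model,                     # placeholder: the LLM being trained
    args=training_args,
    train_dataset=train_dataset,     # placeholder: the long-context dataset
)
trainer.train()                      # launched across the 8 GPUs (e.g. via torchrun)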
Do you spot anything I should refine, or other things worth looking into, please?