### System Info
- `transformers` version: 4.37.2
- Platform: Linux-5.14.0-16…2.6.1.el9_1.x86_64-x86_64-with-glibc2.34
- Python version: 3.11.7
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.26.1
- Deepspeed version: 0.13.1
- Flash-attention version: 2.5.2
- Datasets version: 2.16.1
- PyTorch version (GPU?): 2.1.2+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
### Who can help?
@pacman100
### Information
- [ ] The official example scripts
- [X] My own modified scripts
### Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
### Reproduction
I am continuing pre-training of Llama2-7b-chat-hf on a 3,273,686,325-token corpus of my own data, but training fails with an out-of-memory error at seemingly inconsistent points.
My cluster contains GPU nodes with 4 x A100-80GB GPUs each. The step at which the out-of-memory error occurs varies with the number of GPUs used.
Here is the training script:
```
import datasets
import os
import torch
import argparse
from mpi4py import MPI
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from transformers import DataCollatorForSeq2Seq, default_data_collator
torch.backends.cuda.matmul.allow_tf32 = True
def set_mpi(masteradd):
    """
    Set Open MPI environment variables
    :param masteradd: Value for setting MASTER_ADDR environment variable
    :type masteradd: String
    :return: None
    """
    comm = MPI.COMM_WORLD
    os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
    os.environ["RANK"] = str(comm.Get_rank())
    os.environ['WORLD_SIZE'] = str(comm.Get_size())
    os.environ["MASTER_ADDR"] = masteradd
    os.environ["MASTER_PORT"] = "9978"

def main():
    """
    Set training parameters and train model
    :return: None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--master_add", dest="masteradd")
    args = parser.parse_args()
    set_mpi(args.masteradd)
    experiment_name = ""
    tokenizer_name = 'resized_tokenizer/'
    model_name = 'llama2-7b-chat-hf/'
    out_dir = 'out/'
    os.makedirs(out_dir, exist_ok=True)
    dataset_path = "datasets/"
    dataset_files = [os.path.join(dataset_path, x) for x in os.listdir(dataset_path)]
    dataset = datasets.load_dataset('json', data_files=dataset_files, split='train', cache_dir="cache/")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    training_args = TrainingArguments(
        output_dir=out_dir,
        deepspeed='multi_node_7b.json',
        do_eval=False,
        logging_strategy="steps",
        logging_steps=10,
        learning_rate=2e-5,
        warmup_steps=1000,
        gradient_checkpointing=False,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        tf32=True,
        bf16=True,
        weight_decay=0.1,
        save_total_limit=40,
        push_to_hub=False,
        save_strategy="steps",
        num_train_epochs=1,
        save_steps=1000,
        report_to="tensorboard"
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        do_sample=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer)
    )
    trainer.train(
        resume_from_checkpoint=False,
    )
    trainer.save_model()

if __name__ == "__main__":
    main()
```
Here is the Deepspeed config:
```
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto"
    },
    "gradient_accumulation_steps": 4,
    "gradient_clipping": "auto",
    "gradient_checkpointing": false,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 200,
    "wall_clock_breakdown": false
}
```
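For context on the memory budget as I understand it: with ZeRO stage 1, only the optimizer states are sharded, so every rank still holds a full copy of the bf16 parameters and gradients. Below is a rough back-of-the-envelope estimate, not a measurement; the parameter count and world size are assumptions for my 3-node x 4-GPU run, and it ignores activations, communication buckets, and allocator fragmentation.
```
# Rough per-GPU memory estimate under ZeRO stage 1 with bf16 training.
# All numbers here are assumptions for illustration, not measured values.
n_params = 7e9     # ~7B parameters for llama2-7b-chat-hf
world_size = 12    # 3 nodes x 4 GPUs, matching the "nranks 12" in the NCCL log

params_gb = n_params * 2 / 1e9   # bf16 weights, replicated on every rank
grads_gb = n_params * 2 / 1e9    # bf16 gradients, also replicated under stage 1
# AdamW keeps roughly 12 bytes/param of state (fp32 master weights, momentum,
# variance); ZeRO stage 1 shards this state across ranks.
optim_gb = n_params * 12 / 1e9 / world_size

print(f"params ~{params_gb:.0f} GB, grads ~{grads_gb:.0f} GB, optimizer shard ~{optim_gb:.0f} GB")
# ~14 + 14 + 7 = ~35 GB before activations, leaving roughly 45 GB of headroom
# on an 80 GB card for activations, buffers, and fragmentation.
```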
I launch training from a bash script. Here is the relevant line:
```
deepspeed -H hostfile --master_port 9978 --master_addr $PARENT --no_ssh_check --launcher OPENMPI --launcher_args '--oversubscribe ' deepspeed_7b_finetune.py -m $PARENT
```
Here is the traceback from the failing run:
```
19%|█▉ | 3237/16700 [3:34:12<38:35:22, 10.32s/it]Traceback (most recent call last):
File "/home/user/Hope-Alpha/src/scripts/deepspeed_7b_finetune.py", line 87, in <module>
main()
File "/home/user/Hope-Alpha/src/scripts/deepspeed_7b_finetune.py", line 80, in main
trainer.train(
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/trainer.py", line 1869, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/trainer.py", line 2772, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/trainer.py", line 2795, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1842, in forward
loss = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 795, in forward
hidden_states = self.input_layernorm(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/miniconda3/envs/train-transformers/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 116, in forward
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 116.00 MiB. GPU 3 has a total capacty of 79.32 GiB of which 101.56 MiB is free. Including non-PyTorch memory, this process has 79.22 GiB memory in use. Of the allocated memory 75.96 GiB is allocated by PyTorch, and 1.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
g-10-01:2356899:2357762 [3] NCCL INFO [Service thread] Connection closed by localRank 3
g-10-01:2356899:2356899 [3] NCCL INFO comm 0x9e8f6ea0 rank 3 nranks 12 cudaDev 3 busId e3000 - Abort COMPLETE
```
The dataset consists of 12 `.json` files, which are assembled and cached. Training completes without error on any one of the 12 files individually; when all 12 are assembled, however, the out-of-memory error above occurs. If the files are re-ordered (e.g. `[2,0,1,3,4,5,6,7,8,9,10,11]`), the step at which training fails shifts slightly. If training is resumed from a saved checkpoint using `resume_from_checkpoint = 'checkpoint_dir'`, it runs out of memory at exactly the same step.
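One thing I still want to rule out is a handful of unusually long examples in the assembled corpus, since `DataCollatorForSeq2Seq` with no `max_length` pads each batch to its longest sequence, so activation size varies from step to step. A quick check along these lines (a sketch only; it assumes the `.json` files already contain an `input_ids` column, since the training script passes the loaded dataset to the `Trainer` without further tokenization):
```
import os
import datasets

# Assemble the corpus exactly as in the training script.
dataset_path = "datasets/"
dataset_files = [os.path.join(dataset_path, x) for x in os.listdir(dataset_path)]
dataset = datasets.load_dataset("json", data_files=dataset_files, split="train", cache_dir="cache/")

# Distribution of pre-tokenized sequence lengths across the assembled corpus.
lengths = sorted(len(ex["input_ids"]) for ex in dataset)
n = len(lengths)
print("examples:", n)
print("p50 / p95 / p99 / max:",
      lengths[n // 2], lengths[int(n * 0.95)], lengths[int(n * 0.99)], lengths[-1])
```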
Training on the same dataset using `accelerate` and FSDP completes without issue.
I am at a loss for what could be causing this.
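To narrow down whether memory grows steadily or spikes on particular batches, I am considering logging peak CUDA memory per step with a small callback along these lines (a sketch; the callback name is mine and it is untested in the multi-node run):
```
import torch
from transformers import TrainerCallback

class CudaMemoryLogger(TrainerCallback):
    """Print peak and reserved CUDA memory after every optimizer step."""

    def on_step_end(self, args, state, control, **kwargs):
        peak_gib = torch.cuda.max_memory_allocated() / 1024**3
        reserved_gib = torch.cuda.memory_reserved() / 1024**3
        print(f"step {state.global_step}: peak {peak_gib:.2f} GiB, reserved {reserved_gib:.2f} GiB")
        # Reset so the next step reports its own peak rather than the running max.
        torch.cuda.reset_peak_memory_stats()

# Registered before training with: trainer.add_callback(CudaMemoryLogger())
```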
### Expected behavior
The expected behavior is that training completes a single epoch without running out of memory at seemingly random steps.