Fine-Tuning GPT-J CUDA Memory Error

Hello,

I have been trying to fine-tune GPT-J for causal language modelling using SageMaker model parallelism and keep running into CUDA out-of-memory errors. I have reduced the batch size and gradient accumulation steps, and even tried a p3dn.24xlarge instance, but I still hit the error below:

UnexpectedStatusException: Error for Training job GPT2-v1-2023-02-13-16-30-08-935: Failed. Reason: AlgorithmError: ExecuteUserScriptError:
ExitCode 1
ErrorMessage "RuntimeError:: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 15.78 GiB total capacity; 13.34 GiB already allocated; 194.19 MiB free; 13.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
 --------------------------------------------------------------------------
 Primary job  terminated normally, but 1 process returned
 a non-zero exit code. Per user-direction, the job has been aborted.
 mpirun.real detected that one or more processes exited with non-zero status, thus causing
 the job to be terminated. The first process to do so was
 
 Process name: [[41130,1],0]
 Exit code:    1"
Command "mpirun --host algo-1:4 -np 4 --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include eth0 -mca oob_tcp_if_include eth0 -mca plm_rsh_no_tree_spawn 1 -bind-to none -map-by slot -mca pml o

I think the underlying issue is coming from my input dataset (I am using my own), but I am struggling to understand what the warnings in the training log are telling me. I have reduced the block_size (see the note after the hyperparameters below) but still hit the same error.
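
For context, my rough back-of-envelope arithmetic for why I am using model parallelism at all (assuming roughly 6B fp32 parameters and ignoring optimizer state and activations):

# Rough estimate of the memory needed for the GPT-J weights alone (assumption: ~6B fp32 params).
params = 6e9                        # approximate parameter count
weight_gib = params * 4 / 2**30     # 4 bytes per fp32 parameter -> ~22.4 GiB total
per_partition_gib = weight_gib / 2  # split across 2 model-parallel partitions -> ~11.2 GiB each
print(f"{weight_gib:.1f} GiB total, ~{per_partition_gib:.1f} GiB per partition")

Each V100 on the p3.8xlarge has 16 GiB, so with two partitions I expected the weights themselves to fit, which is why the dataset-related warnings further down are where I have been looking first.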

Here is my code:

from sagemaker.huggingface import HuggingFace

# S3 locations of my own train/eval text files
training_input_path = 's3://1111111111111-dev-datasets/opt/ml/input/fine_tuning_dataset_train_sample.txt'
test_input_path = 's3://1111111111111-dev-datasets/opt/ml/input/fine_tuning_dataset_eval_sample.txt'


hyperparameters = {
    'model_name_or_path': 'EleutherAI/gpt-j-6B',
    'output_dir': '/opt/ml/model',
    'train_file': '/opt/ml/input/data/training/fine_tuning_dataset_train_sample.txt',
    'validation_file': '/opt/ml/input/data/test/fine_tuning_dataset_eval_sample.txt',
    'do_train': True,
    'do_eval': True,
    'per_device_eval_batch_size': 2,
    'per_device_train_batch_size': 2,
    'gradient_accumulation_steps': 1,
}
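
Regarding block_size: the runs where I reduced it just added it as an extra hyperparameter for run_clm.py (the job logged below used the default, hence the 1024 warning). Roughly like this, with 512 only as an illustrative value:

# Variant of the hyperparameters used when reducing block_size
# (512 is illustrative; it maps to run_clm.py's --block_size flag,
# the one the warning in the log refers to).
hyperparameters_reduced_block = {
    **hyperparameters,
    'block_size': 512,
}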

# configuration for running training on smdistributed Model Parallel
mpi_options = {
    "enabled": True,
    "processes_per_host": 4,
}
smp_options = {
    "enabled": True,
    "parameters": {
        "microbatches": 2,
        "placement_strategy": "spread",
        "pipeline": "interleaved",
        "optimize": "speed",
        "partitions": 2,
        "ddp": True,
    }
}

distribution = {
    "smdistributed": {"modelparallel": smp_options},
    "mpi": mpi_options,
}
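
One variant I am considering, but have not run yet, is to tell the model-parallel library to optimise for memory instead of speed and to partition across all four GPUs (my assumption being that "memory" is the memory-saving value of optimize and that partitions can equal the GPU count):

# Untested variant: trade speed for memory headroom by partitioning the model
# across all 4 GPUs. With 4 processes and 4 partitions the data-parallel
# degree drops to 1.
smp_options_memory = {
    "enabled": True,
    "parameters": {
        "microbatches": 2,
        "placement_strategy": "spread",
        "pipeline": "interleaved",
        "optimize": "memory",   # assumption: memory-saving setting
        "partitions": 4,
        "ddp": True,
    }
}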


# git configuration to download our fine-tuning script
git_config = {'repo': 'https://github.com/huggingface/transformers.git','branch': 'v4.17.0'}


# creates the Hugging Face estimator
# (role and output_bucket are defined earlier in my notebook: the SageMaker
# execution role and the S3 output location)
huggingface_estimator = HuggingFace(
    entry_point='run_clm.py',
    source_dir='./examples/pytorch/language-modeling',
    instance_type='ml.p3.8xlarge',
    git_config=git_config,
    instance_count=1,
    role=role,
    transformers_version='4.17.0',
    pytorch_version='1.10.2',
    py_version='py38',
    hyperparameters=hyperparameters,
    output_path=output_bucket,
    base_job_name='GPT2-v1',
    distribution=distribution,
)


# start the training job
huggingface_estimator.fit(inputs={'training': training_input_path,
                                  'test': test_input_path})

And here is the training log. As you can see, there are warnings about hashing, PyArrow caching, block_size, and so on. It would be great to get some help understanding these warnings and pinning down the underlying issue, because the CUDA out-of-memory error on its own doesn't make much sense to me.

Here is the training job log after the model has downloaded:

[1,mpirank:0,algo-1]<stderr>:[INFO|modeling_utils.py:1702] 2023-02-13 16:45:31,555 >> All model checkpoint weights were used when initializing GPTJForCausalLM.
[1,mpirank:0,algo-1]<stderr>:
[1,mpirank:0,algo-1]<stderr>:[INFO|modeling_utils.py:1710] 2023-02-13 16:45:31,555 >> All the weights of GPTJForCausalLM were initialized from the model checkpoint at EleutherAI/gpt-j-6B.
[1,mpirank:0,algo-1]<stderr>:If your task is similar to the task the model of the checkpoint was trained on, you can already use GPTJForCausalLM for predictions without further training.
[1,mpirank:0,algo-1]<stdout>:[INFO|modeling_utils.py:1702] 2023-02-13 16:45:31,555 >> All model checkpoint weights were used when initializing GPTJForCausalLM.
[1,mpirank:0,algo-1]<stdout>:
[1,mpirank:0,algo-1]<stdout>:[INFO|modeling_utils.py:1710] 2023-02-13 16:45:31,555 >> All the weights of GPTJForCausalLM were initialized from the model checkpoint at EleutherAI/gpt-j-6B.
[1,mpirank:0,algo-1]<stdout>:If your task is similar to the task the model of the checkpoint was trained on, you can already use GPTJForCausalLM for predictions without further training.
[1,mpirank:0,algo-1]<stderr>:WARNING:datasets.fingerprint:Parameter 'function'=<function main.<locals>.tokenize_function at 0x7f1f141cfb80> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
[1,mpirank:0,algo-1]<stderr>:#015Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-1c80317fa3b1799d.arrow
[1,mpirank:0,algo-1]<stderr>:#015Running tokenizer on dataset: 100%|██████████| 1/1 [00:00<00:00, 26.64ba/s]
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.fingerprint:Parameter 'function'=<function main.<locals>.tokenize_function at 0x7f1ed9ee63a0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead.
[1,mpirank:0,algo-1]<stderr>:#015Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-bdd640fb06671ad1.arrow
[1,mpirank:0,algo-1]<stderr>:#015Running tokenizer on dataset: 100%|██████████| 1/1 [00:00<00:00, 104.77ba/s]
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO Bootstrap : Using eth0:10.0.81.152<0>
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v4 symbol.
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO NET/OFI Using aws-ofi-nccl 1.2.0aws
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
[1,mpirank:0,algo-1]<stdout>:
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] ofi_init:1157 NCCL WARN NET/OFI Only EFA provider is supported
[1,mpirank:0,algo-1]<stdout>:
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] ofi_init:1208 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO NET/Socket : Using [0]eth0:10.0.81.152<0>
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO Using network Socket
[1,mpirank:0,algo-1]<stdout>:NCCL version 2.10.3+cuda11.3
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO Bootstrap : Using eth0:10.0.81.152<0>
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v4 symbol.
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO NET/OFI Using aws-ofi-nccl 1.2.0aws
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
[1,mpirank:1,algo-1]<stdout>:
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] ofi_init:1157 NCCL WARN NET/OFI Only EFA provider is supported
[1,mpirank:1,algo-1]<stdout>:
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] ofi_init:1208 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO NET/Socket : Using [0]eth0:10.0.81.152<0>
[1,mpirank:1,algo-1]<stdout>:algo-1:81:81 [1] NCCL INFO Using network Socket
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO Bootstrap : Using eth0:10.0.81.152<0>
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v4 symbol.
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO NET/OFI Using aws-ofi-nccl 1.2.0aws
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
[1,mpirank:2,algo-1]<stdout>:
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] ofi_init:1157 NCCL WARN NET/OFI Only EFA provider is supported
[1,mpirank:2,algo-1]<stdout>:
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] ofi_init:1208 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO NET/Socket : Using [0]eth0:10.0.81.152<0>
[1,mpirank:2,algo-1]<stdout>:algo-1:82:82 [2] NCCL INFO Using network Socket
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO Bootstrap : Using eth0:10.0.81.152<0>
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v4 symbol.
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO NET/OFI Using aws-ofi-nccl 1.2.0aws
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
[1,mpirank:3,algo-1]<stdout>:
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] ofi_init:1157 NCCL WARN NET/OFI Only EFA provider is supported
[1,mpirank:3,algo-1]<stdout>:
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] ofi_init:1208 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO NET/Socket : Using [0]eth0:10.0.81.152<0>
[1,mpirank:3,algo-1]<stdout>:algo-1:83:83 [3] NCCL INFO Using network Socket
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO NCCL_MIN_NRINGS set by environment to 4.
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO NCCL_MIN_NRINGS set by environment to 4.
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 00/08 :    0   1   2   3
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 01/08 :    0   3   2   1
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO NCCL_MIN_NRINGS set by environment to 4.
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Trees [0] -1/-1/-1->1->2 [1] 2/-1/-1->1->-1 [2] -1/-1/-1->1->2 [3] 2/-1/-1->1->-1 [4] -1/-1/-1->1->2 [5] 2/-1/-1->1->-1 [6] -1/-1/-1->1->2 [7] 2/-1/-1->1->-1
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO NCCL_MIN_NRINGS set by environment to 4.
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Trees [0] 2/-1/-1->3->0 [1] 0/-1/-1->3->2 [2] 2/-1/-1->3->0 [3] 0/-1/-1->3->2 [4] 2/-1/-1->3->0 [5] 0/-1/-1->3->2 [6] 2/-1/-1->3->0 [7] 0/-1/-1->3->2
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Trees [0] 1/-1/-1->2->3 [1] 3/-1/-1->2->1 [2] 1/-1/-1->2->3 [3] 3/-1/-1->2->1 [4] 1/-1/-1->2->3 [5] 3/-1/-1->2->1 [6] 1/-1/-1->2->3 [7] 3/-1/-1->2->1
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 02/08 :    0   3   1   2
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 03/08 :    0   2   1   3
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 04/08 :    0   1   2   3
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 05/08 :    0   3   2   1
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 06/08 :    0   3   1   2
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 07/08 :    0   2   1   3
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Trees [0] 3/-1/-1->0->-1 [1] -1/-1/-1->0->3 [2] 3/-1/-1->0->-1 [3] -1/-1/-1->0->3 [4] 3/-1/-1->0->-1 [5] -1/-1/-1->0->3 [6] 3/-1/-1->0->-1 [7] -1/-1/-1->0->3
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 00 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 02 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 04 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 06 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 00 : 0[1b0] -> 1[1c0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 00 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 04 : 0[1b0] -> 1[1c0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 03 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 00 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 04 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 04 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 07 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 03 : 0[1b0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 02 : 2[1d0] -> 0[1b0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 03 : 1[1c0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 02 : 3[1e0] -> 1[1c0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 06 : 2[1d0] -> 0[1b0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 07 : 0[1b0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 07 : 1[1c0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 06 : 3[1e0] -> 1[1c0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 01 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 01 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 02 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 03 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 01 : 1[1c0] -> 0[1b0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 01 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 05 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 05 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 06 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 05 : 1[1c0] -> 0[1b0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 05 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 07 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Connected all rings
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Connected all rings
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Connected all rings
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Connected all rings
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 01 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 03 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 05 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Channel 07 : 1[1c0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 01 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 02 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 03 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 01 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 05 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 02 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 06 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 05 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 07 : 2[1d0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 06 : 3[1e0] -> 0[1b0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 00 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 03 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 04 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Channel 07 : 0[1b0] -> 3[1e0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 00 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 02 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 00 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 03 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 02 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 04 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 06 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 04 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Channel 07 : 3[1e0] -> 2[1d0] via P2P/IPC
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Channel 06 : 2[1d0] -> 1[1c0] via P2P/IPC
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO Connected all trees
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/512
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO 8 coll channels, 8 p2p channels, 2 p2p channels per peer
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO Connected all trees
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/512
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO 8 coll channels, 8 p2p channels, 2 p2p channels per peer
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO Connected all trees
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/512
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO 8 coll channels, 8 p2p channels, 2 p2p channels per peer
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO Connected all trees
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/512
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO 8 coll channels, 8 p2p channels, 2 p2p channels per peer
[1,mpirank:0,algo-1]<stdout>:algo-1:80:586 [0] NCCL INFO comm 0x7f1178003010 rank 0 nranks 4 cudaDev 0 busId 1b0 - Init COMPLETE
[1,mpirank:2,algo-1]<stdout>:algo-1:82:588 [2] NCCL INFO comm 0x7f6a9c003040 rank 2 nranks 4 cudaDev 2 busId 1d0 - Init COMPLETE
[1,mpirank:1,algo-1]<stdout>:algo-1:81:587 [1] NCCL INFO comm 0x7f9678003260 rank 1 nranks 4 cudaDev 1 busId 1c0 - Init COMPLETE
[1,mpirank:3,algo-1]<stdout>:algo-1:83:589 [3] NCCL INFO comm 0x7f58d0003040 rank 3 nranks 4 cudaDev 3 busId 1e0 - Init COMPLETE
[1,mpirank:0,algo-1]<stdout>:algo-1:80:80 [0] NCCL INFO Launch mode Parallel
[1,mpirank:0,algo-1]<stderr>:WARNING:__main__:The tokenizer picked seems to have a very large `model_max_length` (2048). Picking 1024 instead. You can change that default value by passing --block_size xxx.
[1,mpirank:1,algo-1]<stderr>:WARNING:datasets.fingerprint:Parameter 'function'=<function main.<locals>.tokenize_function at 0x7f97d4202e50> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
[1,mpirank:1,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-1c80317fa3b1799d.arrow
[1,mpirank:3,algo-1]<stderr>:WARNING:datasets.fingerprint:Parameter 'function'=<function main.<locals>.tokenize_function at 0x7f5a3c55e9d0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
[1,mpirank:2,algo-1]<stderr>:WARNING:datasets.fingerprint:Parameter 'function'=<function main.<locals>.tokenize_function at 0x7f6c00259b80> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
[1,mpirank:3,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-1c80317fa3b1799d.arrow
[1,mpirank:2,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-1c80317fa3b1799d.arrow
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.fingerprint:Parameter 'function'=<function main.<locals>.group_texts at 0x7f1ed9ee64c0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead.
[1,mpirank:0,algo-1]<stderr>:#015Grouping texts in chunks of 1024:   0%|          | 0/1 [00:00<?, ?ba/s]
[1,mpirank:1,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-bdd640fb06671ad1.arrow
[1,mpirank:2,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-bdd640fb06671ad1.arrow
[1,mpirank:3,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-bdd640fb06671ad1.arrow
[1,mpirank:1,algo-1]<stderr>:WARNING:__main__:The tokenizer picked seems to have a very large `model_max_length` (2048). Picking 1024 instead. You can change that default value by passing --block_size xxx.
[1,mpirank:2,algo-1]<stderr>:WARNING:__main__:The tokenizer picked seems to have a very large `model_max_length` (2048). Picking 1024 instead. You can change that default value by passing --block_size xxx.
[1,mpirank:3,algo-1]<stderr>:WARNING:__main__:The tokenizer picked seems to have a very large `model_max_length` (2048). Picking 1024 instead. You can change that default value by passing --block_size xxx.
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-3eb13b9046685257.arrow
[1,mpirank:0,algo-1]<stderr>:#015Grouping texts in chunks of 1024: 100%|██████████| 1/1 [00:00<00:00, 41.75ba/s]
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.fingerprint:Parameter 'function'=<function main.<locals>.group_texts at 0x7f1ed9ee6280> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead.
[1,mpirank:0,algo-1]<stderr>:#015Grouping texts in chunks of 1024:   0%|          | 0/1 [00:00<?, ?ba/s]
[1,mpirank:0,algo-1]<stderr>:INFO:datasets.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-23b8c1e9392456de.arrow
[1,mpirank:0,algo-1]<stderr>:#015Grouping texts in chunks of 1024: 100%|██████████| 1/1 [00:00<00:00, 89.02ba/s]
[1,mpirank:1,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-3eb13b9046685257.arrow
[1,mpirank:2,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-3eb13b9046685257.arrow
[1,mpirank:3,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-3eb13b9046685257.arrow
[1,mpirank:1,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-23b8c1e9392456de.arrow
[1,mpirank:2,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-23b8c1e9392456de.arrow
[1,mpirank:3,algo-1]<stderr>:WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-96d8c801096eb600/0.0.0/08f6fb1dd2dab0a18ea441c359e1d63794ea8cb53e7863e6edf8fc5655e47ec4/cache-23b8c1e9392456de.arrow
[1,mpirank:2,algo-1]<stderr>:#015Downloading:   0%|          | 0.00/1.41k [00:00<?, ?B/s]
[1,mpirank:2,algo-1]<stderr>:#015Downloading: 3.19kB [00:00, 2.80MB/s]
[1,mpirank:2,algo-1]<stderr>:/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
[1,mpirank:2,algo-1]<stderr>:  warnings.warn(
[1,mpirank:3,algo-1]<stderr>:/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
[1,mpirank:3,algo-1]<stderr>:  warnings.warn(
[1,mpirank:1,algo-1]<stderr>:/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
[1,mpirank:1,algo-1]<stderr>:  warnings.warn(
[1,mpirank:0,algo-1]<stderr>:/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning