12% into epoch training loss drops to 0.0

As shown in the logs at the end of this post, the training loss steadily decreases until about 12% of the way through the epoch, then drops to 0.0 and stays there; it has remained at 0.0 through the rest of the logged run (currently about 29% of the epoch) without ever increasing.

Please let me know if there is any other information I can provide beyond what is below. I would really appreciate anyone's insight into this issue; it is a difficult problem to debug because it only arose after more than a day of training (around 250K inputs), so it is expensive and time-consuming to troubleshoot.

The Environment
We are using an Azure Standard_ND96asr_v4 VM (8x A100 40 GB, 320 GB of VRAM in total).

Here are the libraries:

# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.0-py3-none-win_amd64.whl
# pip install transformers
# pip install peft
# pip install accelerate
# pip install datasets
# pip install loralib
# pip install einops
# pip install scipy
# pip install scikit-learn

The Model
We are using mistralai/Mistral-7B-v0.1.

Here is how we are loading the model:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, prepare_model_for_kbit_training

bits_and_bytes = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-v0.1',
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bits_and_bytes,
    cache_dir=hugging_face_cache_dir
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, get_lora_config(base_model))  # get_lora_config is our own helper (a sketch follows below)
model.config.use_cache = False
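
(For completeness: get_lora_config is a small helper of ours that isn't shown above. A minimal sketch of the kind of LoraConfig it might return is below; the rank, alpha, dropout, and target modules are illustrative assumptions, not necessarily what we actually use.)

from peft import LoraConfig

def get_lora_config(base_model):
    # base_model is accepted only to match the call above; unused in this sketch.
    # The rank, alpha, dropout, and target modules are illustrative guesses.
    return LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )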

The Data

  • We are trying to create a chatbot trained using online forum discussions.
  • We have about three million training inputs in CSV format.
  • Each training input contains approximately 575 tokens.
  • Each training input uses the incomplete column for input_ids and the complete column for labels.
  • The incomplete column contains a metadata structure with Topic and Users labels, followed by a chat history (typically a single message, but sometimes multiple messages, as shown below), and always ends with an open-ended response label.
  • The complete column contains the response in addition to a Tag label.
  • There are many reasons we are including this extra metadata, and I would be really interested in hearing your thoughts about it. I don't suspect it is related to the training-loss issue, but perhaps the repeated patterns are causing an overfitting problem?
incomplete,complete
"Topic: Example Topic 1\nUsers: drbob, drjim\ndrbob: This is an example chat message.\ndrjim: ","This is an example response.\nTag: Example 1"
"Topic: Example Topic 2\nUsers: drdan, drjil\ndrdan: This is an example chat message.\ndrjil: This is an example chat response.\ndrdan: ","This is an example response for a chat with more than one user in the incomplete field.\nTag: Example 2"
...

The idea is that we provide the LLM with a prompt string shaped like the incomplete field, ending with an open-ended response label, and the LLM completes the response along with a tag that we can use to label the conversation and to search for related topics to share with the user.
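
To make the intended usage concrete, here is roughly how we expect to prompt the fine-tuned model at inference time (a sketch only; the generation parameters below are illustrative, not something we have tuned):

prompt = (
    "Topic: Example Topic 1\n"
    "Users: drbob, drjim\n"
    "drbob: This is an example chat message.\n"
    "drjim: "
)
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
completion = tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(completion)  # expected: the response text followed by "\nTag: ..."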

Here is how we are loading and tokenizing the data:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=hugging_face_cache_dir)
tokenizer.pad_token = tokenizer.eos_token

train_files = ["""list of CSV files"""]
dataset = load_dataset('csv', data_files=train_files, cache_dir=hugging_face_cache_dir)

def tokenize_function(examples):
    inputs = [text.replace("\\n", "\n") for text in examples['incomplete']]  # preprocessing: replace escaped newline chars with actual newlines
    labels = [text.replace("\\n", "\n") for text in examples['complete']]    # same preprocessing for the labels
    tokenized_inputs = tokenizer(inputs, padding='max_length', max_length=MAX_TOKEN_LENGTH, truncation=True, return_tensors='pt')
    tokenized_labels = tokenizer(labels, padding='max_length', max_length=MAX_TOKEN_LENGTH, truncation=True, return_tensors='pt')
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        'labels': tokenized_labels['input_ids']
    }
tokenized_dataset = dataset.map(tokenize_function, batched=True)
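
A quick spot check of one tokenized row (not part of the training script, just a sanity check we can run after mapping) might look like this:

sample = tokenized_dataset['train'][0]
print(len(sample['input_ids']), len(sample['labels']))    # both should equal MAX_TOKEN_LENGTH
print(sample['input_ids'].count(tokenizer.pad_token_id))  # how much of the prompt is padding (eos)
print(tokenizer.decode(sample['input_ids'], skip_special_tokens=True))
print(tokenizer.decode(sample['labels'], skip_special_tokens=True))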

The Training

  • Because of resource limitations we are only running a single epoch.
  • We are not including a validation split, because evaluation only ever runs on a single GPU, which caused an OOM issue. That was on a lower-end VM, so now that we have more VRAM we could probably add it back, but I am hesitant because it would both slow down training and mean we could only test on the higher-end VM. If the missing validation split turns out to be the culprit, we would of course reimplement it (a rough sketch of what that might look like follows this list).
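
If we did reinstate evaluation, a minimal sketch is below; the 95/5 split, eval_steps, and eval batch size are guesses, not values we have tested:

split_dataset = tokenized_dataset['train'].train_test_split(test_size=0.05, seed=42)

training_args = TrainingArguments(
    # ...same arguments as shown below, plus an evaluation schedule:
    output_dir='/output/path',
    evaluation_strategy='steps',
    eval_steps=50,
    per_device_eval_batch_size=16,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset['train'],
    eval_dataset=split_dataset['test'],
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)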

Here are the training arguments:

import transformers
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    per_device_train_batch_size=144,  # 144 (together with 16 accumulation steps) was tuned via nvidia-smi to minimize training time; ~97% VRAM utilization
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=1,
    output_dir='/output/path',
    overwrite_output_dir=True,
    save_strategy='epoch',
    save_total_limit=3,
    optim='paged_adamw_8bit',
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    remove_unused_columns=True
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)  # resume_from_checkpoint == False, but I am including the option here in case it is somehow related to the issue
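
Because the failure only shows up after a day of training, one mitigation we are considering (a hypothetical sketch, not something we have run yet) is a callback that stops the run as soon as the logged loss hits 0.0 or becomes non-finite:

import math
from transformers import TrainerCallback

class AbortOnBadLoss(TrainerCallback):
    # Stop training as soon as the logged loss is exactly 0.0 or non-finite.
    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get('loss')
        if loss is not None and (loss == 0.0 or not math.isfinite(loss)):
            print(f"Suspicious loss {loss} at step {state.global_step}; stopping early.")
            control.should_training_stop = True
        return control

trainer.add_callback(AbortOnBadLoss())  # added before the trainer.train(...) call above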

The Log
Here is the console output:

You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. It is recommended to update your version as a major bug has been fixed in 8-bit optimizers.
  0%|          | 0/1240 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
C:\Users\removed\PycharmProjects\llmTraining\venv\Lib\site-packages\torch\utils\checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
  0%|          | 1/1240 [08:49<182:23:18, 529.94s/it]{'loss': 3.0828, 'learning_rate': 3.225806451612903e-06, 'epoch': 0.0}
  0%|          | 2/1240 [17:27<179:38:17, 522.37s/it]{'loss': 3.1038, 'learning_rate': 6.451612903225806e-06, 'epoch': 0.0}
  0%|          | 3/1240 [26:04<178:48:09, 520.36s/it]{'loss': 3.0946, 'learning_rate': 9.67741935483871e-06, 'epoch': 0.0}
  0%|          | 4/1240 [34:42<178:17:23, 519.29s/it]{'loss': 3.1097, 'learning_rate': 1.2903225806451613e-05, 'epoch': 0.0}
  0%|          | 5/1240 [43:20<177:59:03, 518.82s/it]{'loss': 3.0952, 'learning_rate': 1.6129032258064517e-05, 'epoch': 0.0}
  0%|          | 6/1240 [51:58<177:41:56, 518.41s/it]{'loss': 3.0645, 'learning_rate': 1.935483870967742e-05, 'epoch': 0.0}
  1%|          | 7/1240 [1:00:35<177:28:52, 518.19s/it]{'loss': 3.0548, 'learning_rate': 2.258064516129032e-05, 'epoch': 0.01}
  1%|          | 8/1240 [1:09:13<177:17:17, 518.05s/it]{'loss': 3.0167, 'learning_rate': 2.5806451612903226e-05, 'epoch': 0.01}
  1%|          | 9/1240 [1:17:51<177:08:46, 518.06s/it]{'loss': 3.0081, 'learning_rate': 2.9032258064516133e-05, 'epoch': 0.01}
  1%|          | 10/1240 [1:26:29<176:59:09, 518.01s/it]{'loss': 2.9889, 'learning_rate': 3.2258064516129034e-05, 'epoch': 0.01}
<this torch.utils.checkpoint use_reentrant warning repeated every ~10 steps throughout the run; the duplicates have been removed below>
  1%|          | 11/1240 [1:35:08<176:53:58, 518.18s/it]{'loss': 2.9646, 'learning_rate': 3.548387096774194e-05, 'epoch': 0.01}
  1%|          | 12/1240 [1:43:46<176:43:31, 518.09s/it]{'loss': 2.9332, 'learning_rate': 3.870967741935484e-05, 'epoch': 0.01}
  1%|          | 13/1240 [1:52:23<176:32:23, 517.97s/it]{'loss': 2.8905, 'learning_rate': 4.1935483870967746e-05, 'epoch': 0.01}
  1%|          | 14/1240 [2:01:01<176:22:26, 517.90s/it]{'loss': 2.8733, 'learning_rate': 4.516129032258064e-05, 'epoch': 0.01}
  1%|          | 15/1240 [2:09:39<176:12:59, 517.86s/it]{'loss': 2.8366, 'learning_rate': 4.8387096774193554e-05, 'epoch': 0.01}
  1%|▏         | 16/1240 [2:18:16<176:00:08, 517.65s/it]{'loss': 2.8278, 'learning_rate': 5.161290322580645e-05, 'epoch': 0.01}
  1%|▏         | 17/1240 [2:26:53<175:49:41, 517.56s/it]{'loss': 2.793, 'learning_rate': 5.4838709677419355e-05, 'epoch': 0.01}
  1%|▏         | 18/1240 [2:35:31<175:41:55, 517.61s/it]{'loss': 2.7755, 'learning_rate': 5.8064516129032266e-05, 'epoch': 0.01}
  2%|▏         | 19/1240 [2:44:09<175:32:47, 517.58s/it]{'loss': 2.7646, 'learning_rate': 6.129032258064517e-05, 'epoch': 0.02}
  2%|▏         | 20/1240 [2:52:46<175:21:25, 517.45s/it]{'loss': 2.7564, 'learning_rate': 6.451612903225807e-05, 'epoch': 0.02}
  2%|▏         | 21/1240 [3:01:23<175:11:01, 517.36s/it]{'loss': 2.7248, 'learning_rate': 6.774193548387096e-05, 'epoch': 0.02}
  2%|▏         | 22/1240 [3:10:00<174:59:53, 517.24s/it]{'loss': 2.7041, 'learning_rate': 7.096774193548388e-05, 'epoch': 0.02}
  2%|▏         | 23/1240 [3:18:37<174:48:37, 517.11s/it]{'loss': 2.6881, 'learning_rate': 7.419354838709677e-05, 'epoch': 0.02}
  2%|▏         | 24/1240 [3:27:13<174:37:05, 516.96s/it]{'loss': 2.6689, 'learning_rate': 7.741935483870968e-05, 'epoch': 0.02}
  2%|▏         | 25/1240 [3:35:49<174:23:11, 516.70s/it]{'loss': 2.6351, 'learning_rate': 8.064516129032258e-05, 'epoch': 0.02}
  2%|▏         | 26/1240 [3:44:25<174:07:02, 516.33s/it]{'loss': 2.6349, 'learning_rate': 8.387096774193549e-05, 'epoch': 0.02}
  2%|▏         | 27/1240 [3:52:59<173:48:21, 515.83s/it]{'loss': 2.6195, 'learning_rate': 8.709677419354839e-05, 'epoch': 0.02}
  2%|▏         | 28/1240 [4:01:33<173:28:14, 515.26s/it]{'loss': 2.6214, 'learning_rate': 9.032258064516129e-05, 'epoch': 0.02}
  2%|▏         | 29/1240 [4:10:08<173:12:51, 514.92s/it]{'loss': 2.6277, 'learning_rate': 9.35483870967742e-05, 'epoch': 0.02}
  2%|▏         | 30/1240 [4:18:44<173:11:24, 515.28s/it]{'loss': 2.6085, 'learning_rate': 9.677419354838711e-05, 'epoch': 0.02}
  2%|▎         | 31/1240 [4:27:21<173:18:05, 516.03s/it]{'loss': 2.6072, 'learning_rate': 0.0001, 'epoch': 0.02}
  3%|▎         | 32/1240 [4:35:56<172:57:32, 515.44s/it]{'loss': 2.5658, 'learning_rate': 0.0001032258064516129, 'epoch': 0.03}
  3%|▎         | 33/1240 [4:44:29<172:37:31, 514.87s/it]{'loss': 2.5862, 'learning_rate': 0.0001064516129032258, 'epoch': 0.03}
  3%|▎         | 34/1240 [4:53:02<172:19:02, 514.38s/it]{'loss': 2.5895, 'learning_rate': 0.00010967741935483871, 'epoch': 0.03}
  3%|▎         | 35/1240 [5:01:36<172:04:48, 514.10s/it]{'loss': 2.5775, 'learning_rate': 0.00011290322580645163, 'epoch': 0.03}
  3%|▎         | 36/1240 [5:10:09<171:53:02, 513.94s/it]{'loss': 2.5797, 'learning_rate': 0.00011612903225806453, 'epoch': 0.03}
  3%|▎         | 37/1240 [5:18:43<171:41:28, 513.79s/it]{'loss': 2.5633, 'learning_rate': 0.00011935483870967743, 'epoch': 0.03}
  3%|▎         | 38/1240 [5:27:17<171:33:10, 513.80s/it]{'loss': 2.5689, 'learning_rate': 0.00012258064516129034, 'epoch': 0.03}
  3%|▎         | 39/1240 [5:35:51<171:26:03, 513.87s/it]{'loss': 2.5667, 'learning_rate': 0.00012580645161290322, 'epoch': 0.03}
  3%|▎         | 40/1240 [5:44:25<171:17:58, 513.90s/it]{'loss': 2.553, 'learning_rate': 0.00012903225806451613, 'epoch': 0.03}
  3%|▎         | 41/1240 [5:52:59<171:11:26, 514.00s/it]{'loss': 2.5391, 'learning_rate': 0.00013225806451612905, 'epoch': 0.03}
  3%|▎         | 42/1240 [6:01:33<171:06:46, 514.20s/it]{'loss': 2.5405, 'learning_rate': 0.00013548387096774193, 'epoch': 0.03}
  3%|▎         | 43/1240 [6:10:08<171:00:00, 514.29s/it]{'loss': 2.5313, 'learning_rate': 0.00013870967741935487, 'epoch': 0.03}
  4%|▎         | 44/1240 [6:18:42<170:49:36, 514.19s/it]{'loss': 2.5308, 'learning_rate': 0.00014193548387096775, 'epoch': 0.04}
  4%|▎         | 45/1240 [6:27:16<170:39:45, 514.13s/it]{'loss': 2.5286, 'learning_rate': 0.00014516129032258066, 'epoch': 0.04}
  4%|▎         | 46/1240 [6:35:50<170:28:59, 514.02s/it]{'loss': 2.5126, 'learning_rate': 0.00014838709677419355, 'epoch': 0.04}
  4%|▍         | 47/1240 [6:44:23<170:18:33, 513.93s/it]{'loss': 2.5034, 'learning_rate': 0.00015161290322580646, 'epoch': 0.04}
  4%|▍         | 48/1240 [6:52:57<170:07:56, 513.82s/it]{'loss': 2.5006, 'learning_rate': 0.00015483870967741937, 'epoch': 0.04}
  4%|▍         | 49/1240 [7:01:30<169:57:20, 513.72s/it]{'loss': 2.508, 'learning_rate': 0.00015806451612903225, 'epoch': 0.04}
  4%|▍         | 50/1240 [7:10:04<169:46:01, 513.58s/it]{'loss': 2.4949, 'learning_rate': 0.00016129032258064516, 'epoch': 0.04}
  4%|▍         | 51/1240 [7:18:37<169:38:00, 513.61s/it]{'loss': 2.4782, 'learning_rate': 0.00016451612903225807, 'epoch': 0.04}
  4%|▍         | 52/1240 [7:27:11<169:31:03, 513.69s/it]{'loss': 2.4743, 'learning_rate': 0.00016774193548387098, 'epoch': 0.04}
  4%|▍         | 53/1240 [7:35:46<169:30:37, 514.10s/it]{'loss': 2.4525, 'learning_rate': 0.0001709677419354839, 'epoch': 0.04}
  4%|▍         | 54/1240 [7:44:23<169:35:20, 514.77s/it]{'loss': 2.4525, 'learning_rate': 0.00017419354838709678, 'epoch': 0.04}
  4%|▍         | 55/1240 [7:53:00<169:41:36, 515.52s/it]{'loss': 2.4447, 'learning_rate': 0.0001774193548387097, 'epoch': 0.04}
  5%|▍         | 56/1240 [8:01:34<169:27:15, 515.23s/it]{'loss': 2.4343, 'learning_rate': 0.00018064516129032257, 'epoch': 0.05}
  5%|▍         | 57/1240 [8:10:08<169:10:00, 514.79s/it]{'loss': 2.4542, 'learning_rate': 0.00018387096774193548, 'epoch': 0.05}
  5%|▍         | 58/1240 [8:18:42<168:54:10, 514.42s/it]{'loss': 2.4198, 'learning_rate': 0.0001870967741935484, 'epoch': 0.05}
  5%|▍         | 59/1240 [8:27:15<168:40:08, 514.15s/it]{'loss': 2.4259, 'learning_rate': 0.0001903225806451613, 'epoch': 0.05}
  5%|▍         | 60/1240 [8:35:49<168:26:41, 513.90s/it]{'loss': 2.4385, 'learning_rate': 0.00019354838709677422, 'epoch': 0.05}
  5%|▍         | 61/1240 [8:44:22<168:16:56, 513.84s/it]{'loss': 2.4259, 'learning_rate': 0.0001967741935483871, 'epoch': 0.05}
  5%|▌         | 62/1240 [8:52:56<168:06:22, 513.74s/it]{'loss': 2.4208, 'learning_rate': 0.0002, 'epoch': 0.05}
  5%|▌         | 63/1240 [9:01:29<167:56:31, 513.67s/it]{'loss': 2.4049, 'learning_rate': 0.00019999964438594984, 'epoch': 0.05}
  5%|▌         | 64/1240 [9:10:03<167:45:30, 513.55s/it]{'loss': 2.4048, 'learning_rate': 0.00019999857754632864, 'epoch': 0.05}
  5%|▌         | 65/1240 [9:18:36<167:37:31, 513.58s/it]{'loss': 2.4, 'learning_rate': 0.00019999679948872395, 'epoch': 0.05}
  5%|▌         | 66/1240 [9:27:10<167:27:48, 513.52s/it]{'loss': 2.3831, 'learning_rate': 0.00019999431022578194, 'epoch': 0.05}
  5%|▌         | 67/1240 [9:35:43<167:20:04, 513.56s/it]{'loss': 2.4004, 'learning_rate': 0.00019999110977520689, 'epoch': 0.05}
  5%|▌         | 68/1240 [9:44:17<167:12:11, 513.59s/it]{'loss': 2.3796, 'learning_rate': 0.0001999871981597613, 'epoch': 0.05}
  6%|▌         | 69/1240 [9:52:50<167:03:15, 513.57s/it]{'loss': 2.3723, 'learning_rate': 0.00019998257540726567, 'epoch': 0.06}
  6%|▌         | 70/1240 [10:01:24<166:53:20, 513.51s/it]{'loss': 2.3838, 'learning_rate': 0.00019997724155059837, 'epoch': 0.06}
  6%|▌         | 71/1240 [10:09:58<166:45:58, 513.57s/it]{'loss': 2.3887, 'learning_rate': 0.00019997119662769523, 'epoch': 0.06}
  6%|▌         | 72/1240 [10:18:31<166:37:21, 513.56s/it]{'loss': 2.3679, 'learning_rate': 0.00019996444068154948, 'epoch': 0.06}
  6%|▌         | 73/1240 [10:27:04<166:27:38, 513.50s/it]{'loss': 2.3573, 'learning_rate': 0.00019995697376021124, 'epoch': 0.06}
<removed to fit maximum post requirements>
  7%|▋         | 88/1240 [12:35:26<164:19:18, 513.51s/it]{'loss': 3.2517, 'learning_rate': 0.00019979523610431249, 'epoch': 0.07}
  7%|▋         | 89/1240 [12:44:00<164:09:31, 513.44s/it]{'loss': 3.259, 'learning_rate': 0.0001997778234064323, 'epoch': 0.07}
  7%|▋         | 90/1240 [12:52:33<164:01:08, 513.45s/it]{'loss': 2.7538, 'learning_rate': 0.0001997778234064323, 'epoch': 0.07}
  7%|▋         | 91/1240 [13:01:07<163:53:41, 513.51s/it]{'loss': 2.778, 'learning_rate': 0.00019975970106063414, 'epoch': 0.07}
  7%|▋         | 92/1240 [13:09:40<163:43:57, 513.45s/it]{'loss': 2.5826, 'learning_rate': 0.00019974086919580928, 'epoch': 0.07}
  8%|▊         | 93/1240 [13:18:13<163:35:13, 513.44s/it]{'loss': 2.4453, 'learning_rate': 0.00019972132794589517, 'epoch': 0.07}
  8%|▊         | 94/1240 [13:26:47<163:25:22, 513.37s/it]{'loss': 2.4104, 'learning_rate': 0.00019970107744987474, 'epoch': 0.08}
  8%|▊         | 95/1240 [13:35:20<163:14:50, 513.27s/it]{'loss': 2.3659, 'learning_rate': 0.00019968011785177515, 'epoch': 0.08}
  8%|▊         | 96/1240 [13:43:53<163:07:49, 513.35s/it]{'loss': 2.3684, 'learning_rate': 0.000199658449300667, 'epoch': 0.08}
  8%|▊         | 97/1240 [13:52:27<163:00:19, 513.40s/it]{'loss': 2.3619, 'learning_rate': 0.0001996360719506631, 'epoch': 0.08}
  8%|▊         | 98/1240 [14:01:01<162:53:56, 513.52s/it]{'loss': 2.3464, 'learning_rate': 0.0001996129859609174, 'epoch': 0.08}
  8%|▊         | 99/1240 [14:09:34<162:44:07, 513.45s/it]{'loss': 2.3586, 'learning_rate': 0.00019958919149562403, 'epoch': 0.08}
  8%|▊         | 100/1240 [14:18:07<162:34:48, 513.41s/it]{'loss': 2.3405, 'learning_rate': 0.00019956468872401586, 'epoch': 0.08}
  8%|▊         | 101/1240 [14:26:41<162:27:09, 513.46s/it]{'loss': 2.3344, 'learning_rate': 0.0001995394778203635, 'epoch': 0.08}
  8%|▊         | 102/1240 [14:35:14<162:17:36, 513.41s/it]{'loss': 2.332, 'learning_rate': 0.00019951355896397402, 'epoch': 0.08}
  8%|▊         | 103/1240 [14:43:47<162:08:27, 513.37s/it]{'loss': 2.3241, 'learning_rate': 0.00019948693233918952, 'epoch': 0.08}
  8%|▊         | 104/1240 [14:52:21<161:59:20, 513.35s/it]{'loss': 2.3315, 'learning_rate': 0.00019945959813538614, 'epoch': 0.08}
  8%|▊         | 105/1240 [15:00:54<161:50:21, 513.32s/it]{'loss': 2.3221, 'learning_rate': 0.0001994315565469723, 'epoch': 0.08}
  9%|▊         | 106/1240 [15:09:27<161:39:41, 513.21s/it]{'loss': 2.3287, 'learning_rate': 0.00019940280777338778, 'epoch': 0.09}
  9%|▊         | 107/1240 [15:18:00<161:32:29, 513.28s/it]{'loss': 2.3142, 'learning_rate': 0.00019937335201910186, 'epoch': 0.09}
  9%|▊         | 108/1240 [15:26:34<161:25:31, 513.37s/it]{'loss': 2.304, 'learning_rate': 0.00019934318949361217, 'epoch': 0.09}
  9%|▉         | 109/1240 [15:35:07<161:18:03, 513.42s/it]{'loss': 2.3068, 'learning_rate': 0.00019931232041144306, 'epoch': 0.09}
  9%|▉         | 110/1240 [15:43:41<161:10:53, 513.50s/it]{'loss': 2.3173, 'learning_rate': 0.0001992807449921441, 'epoch': 0.09}
  9%|▉         | 111/1240 [15:52:15<161:02:48, 513.52s/it]{'loss': 2.3148, 'learning_rate': 0.00019924846346028858, 'epoch': 0.09}
  9%|▉         | 112/1240 [16:00:48<160:55:45, 513.60s/it]{'loss': 2.3051, 'learning_rate': 0.00019921547604547182, 'epoch': 0.09}
  9%|▉         | 113/1240 [16:09:22<160:48:23, 513.67s/it]{'loss': 2.2951, 'learning_rate': 0.00019918178298230954, 'epoch': 0.09}
  9%|▉         | 114/1240 [16:17:56<160:38:04, 513.57s/it]{'loss': 2.2952, 'learning_rate': 0.00019914738451043632, 'epoch': 0.09}
  9%|▉         | 115/1240 [16:26:29<160:29:05, 513.55s/it]{'loss': 2.2944, 'learning_rate': 0.00019911228087450374, 'epoch': 0.09}
  9%|▉         | 116/1240 [16:35:03<160:19:45, 513.51s/it]{'loss': 2.3034, 'learning_rate': 0.00019907647232417873, 'epoch': 0.09}
  9%|▉         | 117/1240 [16:43:36<160:10:56, 513.50s/it]{'loss': 2.2903, 'learning_rate': 0.00019903995911414173, 'epoch': 0.09}
 10%|▉         | 118/1240 [16:52:09<160:01:00, 513.42s/it]{'loss': 2.2638, 'learning_rate': 0.000199002741504085, 'epoch': 0.1}
 10%|▉         | 119/1240 [17:00:42<159:51:20, 513.36s/it]{'loss': 2.281, 'learning_rate': 0.0001989648197587106, 'epoch': 0.1}
 10%|▉         | 120/1240 [17:09:16<159:43:14, 513.39s/it]{'loss': 2.2929, 'learning_rate': 0.0001989261941477287, 'epoch': 0.1}
 10%|▉         | 121/1240 [17:17:50<159:37:12, 513.52s/it]{'loss': 2.2641, 'learning_rate': 0.00019888686494585542, 'epoch': 0.1}
 10%|▉         | 122/1240 [17:26:23<159:28:05, 513.49s/it]{'loss': 2.2882, 'learning_rate': 0.00019884683243281116, 'epoch': 0.1}
 10%|▉         | 123/1240 [17:34:57<159:18:49, 513.46s/it]{'loss': 2.2781, 'learning_rate': 0.00019880609689331834, 'epoch': 0.1}
 10%|█         | 124/1240 [17:43:30<159:10:37, 513.47s/it]{'loss': 2.2613, 'learning_rate': 0.00019876465861709962, 'epoch': 0.1}
 10%|█         | 125/1240 [17:52:04<159:03:28, 513.55s/it]{'loss': 2.2779, 'learning_rate': 0.00019872251789887564, 'epoch': 0.1}
 10%|█         | 126/1240 [18:00:37<158:54:13, 513.51s/it]{'loss': 2.279, 'learning_rate': 0.00019867967503836302, 'epoch': 0.1}
 10%|█         | 127/1240 [18:09:11<158:44:48, 513.47s/it]{'loss': 2.2687, 'learning_rate': 0.00019863613034027224, 'epoch': 0.1}
 10%|█         | 128/1240 [18:17:44<158:35:06, 513.41s/it]{'loss': 2.2721, 'learning_rate': 0.00019859188411430542, 'epoch': 0.1}
 10%|█         | 129/1240 [18:26:17<158:25:29, 513.35s/it]{'loss': 2.2831, 'learning_rate': 0.00019854693667515418, 'epoch': 0.1}
 10%|█         | 130/1240 [18:34:51<158:17:39, 513.39s/it]{'loss': 2.2578, 'learning_rate': 0.0001985012883424973, 'epoch': 0.1}
 11%|█         | 131/1240 [18:43:24<158:10:18, 513.45s/it]{'loss': 2.2627, 'learning_rate': 0.00019845493944099855, 'epoch': 0.11}
 11%|█         | 132/1240 [18:51:58<158:02:31, 513.49s/it]{'loss': 2.268, 'learning_rate': 0.00019840789030030437, 'epoch': 0.11}
 11%|█         | 133/1240 [19:00:32<157:56:17, 513.62s/it]{'loss': 2.2706, 'learning_rate': 0.00019836014125504145, 'epoch': 0.11}
 11%|█         | 134/1240 [19:09:05<157:45:55, 513.52s/it]{'loss': 2.2628, 'learning_rate': 0.00019831169264481445, 'epoch': 0.11}
 11%|█         | 135/1240 [19:17:39<157:37:40, 513.54s/it]{'loss': 2.2647, 'learning_rate': 0.00019826254481420343, 'epoch': 0.11}
 11%|█         | 136/1240 [19:26:12<157:28:22, 513.50s/it]{'loss': 2.2777, 'learning_rate': 0.00019821269811276163, 'epoch': 0.11}
 11%|█         | 137/1240 [19:34:45<157:18:07, 513.41s/it]{'loss': 2.2677, 'learning_rate': 0.00019816215289501278, 'epoch': 0.11}
 11%|█         | 138/1240 [19:43:19<157:09:42, 513.41s/it]{'loss': 2.2728, 'learning_rate': 0.00019811090952044866, 'epoch': 0.11}
 11%|█         | 139/1240 [19:51:52<157:03:05, 513.52s/it]{'loss': 2.2577, 'learning_rate': 0.0001980589683535266, 'epoch': 0.11}
 11%|█▏        | 140/1240 [20:00:26<156:55:13, 513.56s/it]{'loss': 2.2622, 'learning_rate': 0.00019800632976366668, 'epoch': 0.11}
 11%|█▏        | 141/1240 [20:09:00<156:47:10, 513.59s/it]{'loss': 2.2462, 'learning_rate': 0.00019795299412524945, 'epoch': 0.11}
 11%|█▏        | 142/1240 [20:17:33<156:37:17, 513.51s/it]{'loss': 2.2616, 'learning_rate': 0.0001978989618176129, 'epoch': 0.11}
 12%|█▏        | 143/1240 [20:26:06<156:27:10, 513.43s/it]{'loss': 2.2593, 'learning_rate': 0.00019784423322505, 'epoch': 0.12}
 12%|█▏        | 144/1240 [20:34:40<156:18:04, 513.40s/it]{'loss': 2.2423, 'learning_rate': 0.00019778880873680586, 'epoch': 0.12}
 12%|█▏        | 145/1240 [20:43:13<156:08:35, 513.35s/it]{'loss': 2.2398, 'learning_rate': 0.00019773268874707502, 'epoch': 0.12}
 12%|█▏        | 146/1240 [20:51:46<155:57:40, 513.22s/it]{'loss': 2.247, 'learning_rate': 0.00019767587365499865, 'epoch': 0.12}
 12%|█▏        | 147/1240 [21:00:14<155:22:39, 511.77s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.12}
 12%|█▏        | 148/1240 [21:08:43<154:56:35, 510.80s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.12}
 12%|█▏        | 149/1240 [21:17:11<154:36:06, 510.14s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.12}
 12%|█▏        | 150/1240 [21:25:39<154:17:12, 509.57s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.12}
<The loss stayed at 0.0 and the learning_rate stayed at 0.00019767587365499865 all the way until the final log below>
 29%|██▊       | 355/1240 [50:18:04<124:38:03, 506.99s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.29}
 29%|██▊       | 356/1240 [50:26:30<124:27:32, 506.85s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.29}
 29%|██▉       | 357/1240 [50:34:57<124:19:35, 506.88s/it]{'loss': 0.0, 'learning_rate': 0.00019767587365499865, 'epoch': 0.29}

I just noticed “You are using 8-bit optimizers with a version of bitsandbytes < 0.41.1. It is recommended to update your version as a major bug has been fixed in 8-bit optimizers.”, so I will start there, but thanks in advance to anyone with insight into this.

I ran into a similar problem, but on a much smaller model and dataset, where the loss dropped to 0.0 within a very short period of time and I was able to keep training long enough afterwards to see what happened.

In my experiments, this was because NaN values appeared in the model (and, strangely, in some parts of the loss as well). Hugging Face handles this by first outputting a sequence of zeros, and after a while they become NaN loss. You should be able to see this if you are tracking your gradients.
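
A rough way to check for this (just a sketch) is to scan the trainable parameters' gradients after a backward pass, or from inside a callback:

import torch

# Scan the trainable (LoRA) parameters for non-finite gradients.
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"Non-finite gradient in {name}")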

Curious if you had made any progress on this?

See my issue here: TRL SFT super prone to nan when using data collator


Try using paged_adamw_32bit.
