Loss.backward() producing nan values with 8-bit Llama-3-70B-Instruct

I’m getting nan values when I backpropagate through Llama-3-70B-Instruct loaded in 8-bit on a simple input. System info and a reproduction are below. Any help would be greatly appreciated!

System Info

  • transformers version: 4.40.1
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.31
  • Python version: 3.11.7
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True (NVIDIA A100)
  • Using distributed or parallel set-up in script?: False
  • bitsandbytes version: 0.43.1

Reproduction

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Surface the first op that produces nan during backward
torch.autograd.set_detect_anomaly(True)

hf_token = ""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct", trust_remote_code=True, token=hf_token)

# Load the model with 8-bit quantization via bitsandbytes
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct",
                                             trust_remote_code=True,
                                             token=hf_token,
                                             device_map="auto",
                                             max_memory={0: '80GIB'},
                                             quantization_config=bnb_config,
                                             low_cpu_mem_usage=True)

inputs = torch.tensor([[128000, 128006,   9125, 128007,    271,  16533,    279,   2768,   3488,
            449,   1193,    264,   3254,  12360,    596,    836,    323,    912,
           5217,   1495,     13, 128009, 128006,    882, 128007,    271,    678,
            459,  12360,    304,    279,   5818,  19574,   1369,   1147,    320,
           2550,     15,    570,  22559,    449,   1193,    264,   3254,  12360,
            596,    836,    323,    912,   5217,   1495,     13, 128009, 128006,
          78191, 128007,    271,   4873,     75,    783,    473,    478]]).long().to(model.device)

logits = model(inputs)['logits']
probs = torch.nn.functional.softmax(logits, dim=-1)
# Negative log-probability of token 263 at the final position
loss = -torch.log(probs[0, -1, 263])
loss.backward()

This produces the following error message and stack trace:

/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue
    await self.process_one()
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 531, in process_one
    await dispatch(*args)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 775, in execute_request
    reply_content = await reply_content
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
    res = shell.run_cell(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell
    result = self._run_cell(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell
    result = runner(coro)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1296557/2560056571.py", line 34, in <module>
    logits = model(inputs)['logits']
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1208, in forward
    outputs = self.model(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1018, in forward
    layer_outputs = decoder_layer(
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 738, in forward
    hidden_states = self.input_layernorm(hidden_states)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 89, in forward
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 38
     36 print(probs[0, -1, 263])
     37 loss = -torch.log(probs[0, -1, 263])
---> 38 loss.backward()

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

Expected behavior

No nan values.
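Side note on the loss computation itself: taking log of softmax probabilities can underflow to 0 and give -inf/nan independently of the quantization issue. A numerically safer way to compute the same quantity (a sketch reusing the variables from the reproduction above; it does not address the 8-bit backward limitation) is:

import torch.nn.functional as F

logits = model(inputs)["logits"]
# log_softmax avoids log(0) from probabilities that underflow in plain softmax;
# upcasting to fp32 adds a little more headroom.
log_probs = F.log_softmax(logits.float(), dim=-1)
loss = -log_probs[0, -1, 263]
loss.backward()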


Hi,

Could you try getting the loss directly from the model, by passing input_ids and labels to it? That way you get outputs.loss:

labels = input_ids.clone()
labels[labels == pad_token_id] = -100

outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
loss.backward()
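For context, the -100 value works because the model's built-in loss uses CrossEntropyLoss with ignore_index=-100, so masked positions contribute nothing. Roughly (a sketch, not the exact transformers code):

import torch.nn.functional as F

# Shift so each position predicts the next token, then skip labels set to -100
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                       shift_labels.view(-1),
                       ignore_index=-100)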

I am still getting the same error and stack trace with that approach. Since I am only interested in the loss for predicting the last token of the input (263), the code I adapted from your suggestion is below:

input_ids = torch.tensor([[128000, 128006,   9125, 128007,    271,  16533,    279,   2768,   3488,
            449,   1193,    264,   3254,  12360,    596,    836,    323,    912,
           5217,   1495,     13, 128009, 128006,    882, 128007,    271,    678,
            459,  12360,    304,    279,   5818,  19574,   1369,   1147,    320,
           2550,     15,    570,  22559,    449,   1193,    264,   3254,  12360,
            596,    836,    323,    912,   5217,   1495,     13, 128009, 128006,
          78191, 128007,    271,   4873,     75,    783,    473,    478,  263]]).long().to(model.device)

# Ignore every position except the final one; -100 is the ignore index used by the loss
labels = torch.zeros(input_ids.shape) - 100
labels[0, -1] = 263
labels = labels.long().to(model.device)

outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
loss.backward()
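As a quick sanity check (a sketch using the tensors defined above), one can confirm whether the nan already appears in the forward pass or only during backward:

# If either of these prints True, the forward pass is already producing nan/inf;
# otherwise the nan is introduced during the backward computation.
print(torch.isnan(outputs.logits).any(), torch.isinf(outputs.logits).any())
print(outputs.loss)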

Update: this is not supported for quantized models. See loss.backward() producing nan values with 8-bit Llama-3-70B-Instruct · Issue #30526 · huggingface/transformers · GitHub
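For anyone who lands here later: a common workaround (my assumption, not taken from the linked issue) is to keep the base weights quantized and train small LoRA adapters via the peft library, so gradients flow through the fp16/fp32 adapter weights rather than the int8 base weights. A rough sketch:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Hypothetical workaround sketch: the int8 base stays frozen, gradients land on the adapters.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=16,
                         lora_alpha=32,
                         target_modules=["q_proj", "v_proj"],  # assumed Llama attention projections
                         task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

outputs = model(input_ids=input_ids, labels=labels)
outputs.loss.backward()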