I’m getting NaN values when I backpropagate through Llama-3-70B loaded in 8-bit (bitsandbytes) on a simple input. System info and reproduction are below. Any help would be greatly appreciated!
System Info
- transformers version: 4.40.1
- Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.31
- Python version: 3.11.7
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.26.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True (NVIDIA A100)
- Using distributed or parallel set-up in script?: False
- bitsandbytes version: 0.43.1
Reproduction
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

torch.autograd.set_detect_anomaly(True)

hf_token = ""
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    trust_remote_code=True,
    token=hf_token,
)

# Load the model in 8-bit via bitsandbytes.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    trust_remote_code=True,
    token=hf_token,
    device_map="auto",
    max_memory={0: '80GIB'},
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)

# Pre-tokenized Llama-3 chat prompt.
inputs = torch.tensor([[128000, 128006, 9125, 128007, 271, 16533, 279, 2768, 3488,
                        449, 1193, 264, 3254, 12360, 596, 836, 323, 912,
                        5217, 1495, 13, 128009, 128006, 882, 128007, 271, 678,
                        459, 12360, 304, 279, 5818, 19574, 1369, 1147, 320,
                        2550, 15, 570, 22559, 449, 1193, 264, 3254, 12360,
                        596, 836, 323, 912, 5217, 1495, 13, 128009, 128006,
                        78191, 128007, 271, 4873, 75, 783, 473, 478]]).long().to(model.device)

# Forward pass, then the negative log-probability of token id 263 at the last position.
logits = model(inputs)['logits']
probs = torch.nn.functional.softmax(logits, dim=-1)
loss = -torch.log(probs[0, -1, 263])
loss.backward()
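For what it's worth, the same target log-probability could also be computed with log_softmax instead of an explicit softmax followed by log, which avoids taking the log of a probability that may underflow to zero in reduced precision. This is only a sketch of that variant; I have not verified whether it changes the outcome here.

# Untested alternative: same loss via log_softmax in float32.
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
loss = -log_probs[0, -1, 263]
loss.backward()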
Running the original script (with the explicit softmax followed by log) produces the following error message and stack trace:
/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
app.start()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
self.io_loop.start()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
self.asyncio_loop.run_forever()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
self._run_once()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
handle._run()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue
await self.process_one()
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 531, in process_one
await dispatch(*args)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
await result
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
await super().execute_request(stream, ident, parent)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 775, in execute_request
reply_content = await reply_content
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
res = shell.run_cell(
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
return super().run_cell(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell
result = self._run_cell(
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell
result = runner(coro)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_1296557/2560056571.py", line 34, in <module>
logits = model(inputs)['logits']
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = module._old_forward(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1208, in forward
outputs = self.model(
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = module._old_forward(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1018, in forward
layer_outputs = decoder_layer(
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = module._old_forward(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 738, in forward
hidden_states = self.input_layernorm(hidden_states)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = module._old_forward(*args, **kwargs)
File "/users/htim/miniconda3/envs/minenv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 89, in forward
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 38
36 print(probs[0, -1, 263])
37 loss = -torch.log(probs[0, -1, 263])
---> 38 loss.backward()
File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
515 if has_torch_function_unary(self):
516 return handle_torch_function(
517 Tensor.backward,
518 (self,),
(...)
523 inputs=inputs,
524 )
--> 525 torch.autograd.backward(
526 self, gradient, retain_graph, create_graph, inputs=inputs
527 )
File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
262 retain_graph = create_graph
264 # The reason we repeat the same comment below is that
265 # some Python versions print out the first line of a multi-line function
266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
268 tensors,
269 grad_tensors_,
270 retain_graph,
271 create_graph,
272 inputs,
273 allow_unreachable=True,
274 accumulate_grad=True,
275 )
File ~/miniconda3/envs/minenv/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
742 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
743 try:
--> 744 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
745 t_outputs, *args, **kwargs
746 ) # Calls into the C++ engine to run the backward pass
747 finally:
748 if attach_logging_hooks:
RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.
Expected behavior
No NaN values; loss.backward() should complete without the anomaly-detection error.
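In case it helps with triage, below is a sketch (not tested against this exact setup) of how the module that first produces non-finite activations could be located with forward hooks, reusing model and inputs from the reproduction above. Hooks fire in execution order, so the first name printed is the earliest offender.

# Sketch: flag every module whose output contains NaN or inf during the forward pass.
def make_hook(name):
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        if torch.is_tensor(out) and not torch.isfinite(out).all():
            print(f"non-finite output in: {name}")
    return hook

handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
logits = model(inputs)['logits']
for h in handles:
    h.remove()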