Code:
from accelerate import init_empty_weights
from transformers import OPTForCausalLM, AutoTokenizer
import torch
with init_empty_weights():
model = OPTForCausalLM.from_pretrained(
"facebook/opt-1.3b",
device_map="auto",
offload_folder="/tmp/opt-1.3b-offload-accelerate",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
inputs = inputs.to(0)
output = model.generate(inputs["input_ids"])
print(tokenizer.decode(output[0].tolist()))
Traceback:
Traceback (most recent call last):
File "run_inference.py", line 6, in <module>
model = OPTForCausalLM.from_pretrained(
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2529, in from_pretrained
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/big_modeling.py", line 318, in dispatch_model
attach_align_device_hook_on_blocks(
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/hooks.py", line 488, in attach_align_device_hook_on_blocks
attach_align_device_hook_on_blocks(
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/hooks.py", line 464, in attach_align_device_hook_on_blocks
add_hook_to_module(module, hook)
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/hooks.py", line 148, in add_hook_to_module
module = hook.init_hook(module)
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/hooks.py", line 237, in init_hook
set_module_tensor_to_device(module, name, self.execution_device)
File "/home/sh0416/anaconda3/envs/personal/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 127, in set_module_tensor_to_device
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
ValueError: weight is on the meta device, we need a `value` to put in on 1.
How to resolve this error…? I don’t know how to do it…