I'm getting complete nonsense when I use Llama-2's forward function (see below). The output I get with the `.generate()` function is a lot better.
The reason I need the forward function is that I have to train my model in a custom PyTorch training loop, and as far as I understand, `.generate()` can't be used for training.
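For context, the forward pass is what a training loop relies on. When `labels` are passed to `LlamaForCausalLM` (typically a copy of `input_ids`), the model shifts logits and labels against each other internally, so the logits at position i are scored against token i + 1. The sketch below reproduces that shifted cross-entropy in plain Python on made-up toy logits so the alignment is easy to see. The function names and the toy vocabulary are hypothetical, and this is not the library's implementation.

```python
import math


def log_softmax(row):
    # Numerically stable log-softmax over one position's vocabulary scores.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def shifted_cross_entropy(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy: logits[i] is scored against labels[i + 1].

    logits: one list of vocabulary scores per input position.
    labels: one token id per input position; ignore_index entries are skipped.
    """
    total, count = 0.0, 0
    # Drop the last position's logits (nothing follows it) and the first label
    # (nothing predicts it), mirroring the shift a causal LM applies.
    for row, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue
        total -= log_softmax(row)[target]
        count += 1
    return total / count


# Toy example: vocabulary of 4 tokens, sequence of 3 tokens.
input_ids = [0, 2, 3]
logits = [
    [0.1, 0.2, 3.0, 0.1],  # position 0 strongly predicts token 2 (the next token)
    [0.1, 0.2, 0.1, 3.0],  # position 1 strongly predicts token 3 (the next token)
    [5.0, 0.0, 0.0, 0.0],  # position 2 has no next token, so the shift ignores it
]
loss = shifted_cross_entropy(logits, labels=input_ids)
print(f"loss = {loss:.4f}")
```

In a real loop the equivalent would be something like `outputs = model(**model_input, labels=model_input["input_ids"])` followed by `outputs.loss.backward()`. Because the shift happens inside the model, the labels should not be shifted by hand.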
Here is a minimal "working" example:
```python
import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaTokenizer

# model_id = "meta-llama/Llama-2-7b-chat-hf"
model_id = "meta-llama/Llama-2-7b-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = LlamaTokenizer.from_pretrained(model_id)
tokenizer.add_special_tokens({"pad_token": "<pad>"})

model = LlamaForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto", cache_dir="./cache")
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()

model_input = tokenizer(
    # "Hello, how are you? ###Assistant:",
    "Hello, how are you?",
    return_tensors="pt",
    max_length=20,
    truncation=True,
    # padding="max_length",
)
model_input["input_ids"] = model_input["input_ids"].to("cuda")
model_input["attention_mask"] = model_input["attention_mask"].to("cuda")

model_output = model.generate(model_input['input_ids'], max_new_tokens=50)
# print(model_output)
output_string = tokenizer.batch_decode(model_output)[0]
print("Output with `.generate()`:\n" + output_string)
print("\n")

model_output = model(**model_input)
# print(model_output.logits.shape)
output_string = tokenizer.decode(torch.argmax(model_output.logits.squeeze(), -1))
print("Output with `.forward()`:\n" + output_string)
```
Output:

```
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00, 3.87s/it]
Output with `.generate()`:
<s> Hello, how are you? I'm doing well, thanks for asking. everybody is in good health, so I am happy. I hope you are well too.
I'm very glad that you have visited my website. I'm sure you are looking for a
Output with `.forward()`:
nobody, I are you? I
```
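One plausible reading of the `.forward()` output above: a single forward pass returns one next-token prediction per input position, each conditioned on the true prompt prefix (teacher forcing). Taking the argmax at every position and decoding therefore yields roughly the prompt shifted left by one token plus a single new token, not a continuation. The sketch below illustrates the difference with a hypothetical toy bigram predictor (`NEXT`, `predict_next` and the other helpers are made up for illustration); it is not how Llama-2 computes its predictions.

```python
# Hypothetical toy "model": maps the last token of a prefix to its most likely
# next token. A real causal LM scores every position of the input in one call.
NEXT = {
    "<s>": "Hello", "Hello": ",", ",": "how", "how": "are", "are": "you",
    "you": "?", "?": "I", "I": "am", "am": "fine", "fine": ".",
}


def predict_next(prefix):
    """Greedy next-token guess for a prefix (toy stand-in for argmax over logits)."""
    return NEXT.get(prefix[-1], ".")


def forward_argmax(tokens):
    """Like argmax over forward() logits: one guess per position, each made
    from the true prefix tokens[: i + 1], never fed back into the input."""
    return [predict_next(tokens[: i + 1]) for i in range(len(tokens))]


def greedy_generate(tokens, max_new_tokens):
    """Like greedy generate(): append each guess and predict again from it."""
    out = list(tokens)
    for _ in range(max_new_tokens):
        out.append(predict_next(out))
    return out


prompt = ["<s>", "Hello", ",", "how", "are", "you", "?"]
print(" ".join(forward_argmax(prompt)))      # prompt shifted by one, plus one new token
print(" ".join(greedy_generate(prompt, 3)))  # prompt followed by three new tokens
```

Read this way, `nobody, I are you? I` would be the model's guess after `<s>`, after `Hello`, after `,`, and so on, with only the final `I` predicting what follows the prompt.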
It might be worth noting that the output from the `.generate()` function changes every time I rerun the script, but the forward function always gives me the same gibberish. Sometimes `.generate()` also gives me gibberish that is at least somewhat grammatical, or it slips into (grammatical) German text unprompted.
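If sampling is enabled for `.generate()` (via `do_sample=True`, either passed explicitly or set in the model's generation config), each new token is drawn at random from the softmax of the logits. That would explain why its output changes between runs, while `torch.argmax` over the forward logits always picks the same token. A minimal pure-Python sketch contrasting the two selection rules on hypothetical toy logits (the helper names and values are made up for illustration):

```python
import math
import random


def softmax(logits, temperature=1.0):
    # Convert scores to probabilities; lower temperature sharpens the distribution.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def greedy_pick(logits):
    # Argmax: the same logits always give the same token id.
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_pick(logits, rng, temperature=1.0):
    # Sampling: draw a token id in proportion to its softmax probability.
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.5, 0.5, -1.0]
print([greedy_pick(logits) for _ in range(5)])  # identical on every run
rng = random.Random()                           # unseeded: differs between runs
print([sample_pick(logits, rng, temperature=0.8) for _ in range(5)])
```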
I'm using 4-bit quantization; could that have anything to do with it?
I've been trying to troubleshoot this for two weeks and I'm getting really desperate. Any help would be much appreciated.