Hallucination with trainer.evaluate() on LLMs

I am trying to use the Trainer to evaluate decoder-only LLMs with custom metrics. Specifically, I am reloading a quantized LoRA checkpoint and running with do_eval=True and do_train=False.

The issue is that the model's predictions contain a lot of hallucinations, even in the prompt portion. Yet the same model generates almost perfectly when I load it outside the Trainer and make predictions with the model.generate() method.

This happens with Llama-2, but the same issue also shows up with Falcon.
The code is mostly based on the run_clm.py example and the Falcon training script.

Here is the code that I came up with:

    import torch
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForCausalLM

    peft_config = PeftConfig.from_pretrained(script_args.peft_model)
    base_model = AutoModelForCausalLM.from_pretrained(  # quantization_config=bnb_config (commented out)
        peft_config.base_model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16,
        device_map={"": 0}, load_in_8bit=True)  # device_map={"": 0} already places the model on GPU 0
    lora_model = PeftModel.from_pretrained(base_model, script_args.peft_model).to('cuda')

Here are the helper functions for the metrics:

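    import evaluate
    import numpy as np

    # Load the metric once instead of on every evaluation call.
    acc_metric = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        # preds are the argmax token ids returned by preprocess_logits_for_metrics
        preds, labels = eval_preds
        # The accuracy metric expects flat 1-D sequences, not (batch, seq_len) arrays.
        mask = labels != -100
        acc = acc_metric.compute(predictions=preds[mask], references=labels[mask])

        # Replace the -100 padding added when predictions are gathered, so they can be decoded.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds)
        return {"ACC": acc["accuracy"]}

    def preprocess_logits_for_metrics(logits, labels):
        # Some models return a tuple (logits, past_key_values, ...); keep only the logits.
        if isinstance(logits, tuple):
            logits = logits[0]
        # Reduce (batch, seq_len, vocab) logits to token ids to save memory.
        return logits.argmax(dim=-1)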
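For comparison, run_clm.py does not compare the argmax predictions with the labels position by position: the logits at position i score the token at position i + 1, so its compute_metrics compares preds[:, :-1] against labels[:, 1:]. Below is a minimal sketch of that alignment in plain NumPy rather than the evaluate library. The helper name `shifted_token_accuracy` is hypothetical, and skipping label positions set to -100 (prompt or padding tokens) is an added assumption, not something run_clm.py does.

```python
import numpy as np


def shifted_token_accuracy(preds: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Token accuracy for a causal LM from argmax predictions and labels.

    preds and labels both have shape (batch, seq_len). preds[:, i] is the
    model's guess for the token at position i + 1, so the two arrays are
    shifted against each other before they are compared.
    """
    preds = preds[:, :-1].reshape(-1)   # drop the guess made after the last token
    labels = labels[:, 1:].reshape(-1)  # drop the first token, which nothing predicts
    mask = labels != ignore_index       # assumption: -100 marks positions to skip
    if not mask.any():
        return 0.0
    return float((preds[mask] == labels[mask]).mean())


# Toy example: every next token is predicted correctly,
# yet a naive position-by-position comparison scores 0.
labels = np.array([[5, 6, 7, 8]])
preds = np.array([[6, 7, 8, 9]])  # preds[0, i] is the guess for labels[0, i + 1]
print((preds == labels).mean())               # 0.0
print(shifted_token_accuracy(preds, labels))  # 1.0
```

The same offset applies to decoded_preds: decoding the argmax ids in place lines each guess up with the token before the one it is predicting.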

Here is the trainer setup and the evaluation call:

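    from transformers import Trainer

    trainer = Trainer(
        # remaining arguments omitted from the post (presumably model, args, eval_dataset, compute_metrics, ...)
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    if training_arguments.do_eval:
        metrics = trainer.evaluate()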

Here is an example comparing decoded_preds with the output of model.generate().

decoded_preds from trainer.evaluate():

    ## Exampleru
    ##ract the from the file text and the above above above.

model.generate() output, with default parameters except for a larger max_new_tokens:

    # Instruction
    Extract data from the following document using the rules specified below.
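One plausible reading of this example, offered as an assumption rather than a confirmed diagnosis: trainer.evaluate() runs a single teacher-forced forward pass, so each argmax token in decoded_preds is the model's guess for the next token given the gold prefix, and that guess is never fed back in. model.generate() instead appends each prediction and conditions on it. The toy sketch below contrasts the two modes with a hypothetical lookup-table "model" (`TABLE`, `next_token`, `teacher_forced_argmax` and `greedy_generate` are all illustrative names, not transformers APIs).

```python
from typing import Dict, List, Tuple

# Hypothetical toy "model": maps the last two tokens to a predicted next token.
# Any context missing from the table (e.g. a single token) falls back to "the".
TABLE: Dict[Tuple[str, ...], str] = {
    ("<s>", "#"): "Instruction",
    ("#", "Instruction"): "Extract",
    ("Instruction", "Extract"): "data",
    ("Extract", "data"): "from",
}


def next_token(context: List[str]) -> str:
    """Greedy (argmax) next-token prediction of the toy model."""
    return TABLE.get(tuple(context[-2:]), "the")


def teacher_forced_argmax(gold: List[str]) -> List[str]:
    """Like evaluate() + argmax: one guess per gold prefix, never fed back."""
    return [next_token(gold[: i + 1]) for i in range(len(gold))]


def greedy_generate(prompt: List[str], max_new_tokens: int) -> List[str]:
    """Like greedy generate(): each guess is appended and conditioned on."""
    out = list(prompt)
    for _ in range(max_new_tokens):
        out.append(next_token(out))
    return out[len(prompt):]


gold = ["<s>", "#", "Instruction", "Extract", "data"]
# Guess i targets gold[i + 1], so read in place the output is shifted by one,
# and the first guess is made from "<s>" alone.
print(teacher_forced_argmax(gold))   # ['the', 'Instruction', 'Extract', 'data', 'from']
print(greedy_generate(gold[:2], 3))  # ['Instruction', 'Extract', 'data']
```

With a real model, guesses made from very short prefixes and at positions where the gold text is hard to predict are often wrong, which is one way text like the decoded_preds block can arise even when generation looks fine.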

Library versions:
transformers 4.32.0.dev0
peft 0.5.0.dev0

Did you find the solution?