Difference between trainer.predict() and model.generate() for LM

Hi, I am training LLaMA with Trainer. I would like to get generations on the training data with trainer.predict(). I apply argmax to the raw predictions for decoding, which I assume should be equivalent to greedy decoding. But the decoded result is different from what model.generate() produces given the same input. Additionally, trainer.predict() only generates one more word on top of repeating the input, while model.generate() produces many more words. So how does trainer.predict() actually work?

Related code for decoding the trainer.predict() output:

import numpy as np

def logit2seqGreedy(pred):
    # pick the highest-scoring token at each position
    return np.argmax(pred, axis=-1)

for i in range(predictions.shape[0]):
    decode_pred = logit2seqGreedy(predictions[i])
    # replace 0s with the pad token id before decoding
    pred_tokens = np.where(decode_pred != 0, decode_pred, tokenizer.pad_token_id).reshape(1, -1)
    pred_tokens = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
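
For context, `predictions` above comes from something like this (a rough sketch; `train_dataset` here is just a placeholder for my tokenized training split):

# trainer.predict() returns a PredictionOutput whose .predictions field holds the raw logits,
# one row of shape (seq_len, vocab_size) per example
pred_output = trainer.predict(train_dataset)
predictions = pred_output.predictions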

Code for model.generate():

import torch
from transformers import GenerationConfig

# num_beams=1 with do_sample left at its default (False) means greedy decoding;
# temperature and top_p are ignored in that case
generation_config = GenerationConfig(
    temperature=0,
    top_p=1,
    num_beams=1,
)

load_pretrained = 0
if load_pretrained:
    model = LLaMAForCausalLM.from_pretrained(output)

def tokenize(prompt, add_eos_token=True):
    # tokenize to plain Python lists (no return_tensors) so the eos append below works;
    # evaluate() converts input_ids to a tensor afterwards
    result = tokenizer(prompt)
    for i in range(len(prompt)):
        if add_eos_token and result["input_ids"][i][-1] != tokenizer.eos_token_id:
            result["input_ids"][i].append(tokenizer.eos_token_id)
            result["attention_mask"][i].append(1)
    return result

def evaluate(prompt):
    # inputs = tokenizer(prompt, return_tensors="pt")
    inputs = tokenize([prompt], 0)  # add_eos_token=0, so no eos is appended
    input_ids = torch.tensor(inputs["input_ids"])

    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=150,
    )

    for s in generation_output.sequences:
        output = tokenizer.decode(s)
        return output

Data input template for trainer.predict() and model.generate()

def getInps_eval(d):
    return f"""### Input:
        {d['que']} ### Response: The answer is ans"""
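
To compare the two paths, I build the prompt from this template and pass it to evaluate() (the dict below is just a made-up example row; 'que' holds the question text):

d = {"que": "What is the capital of France?"}
prompt = getInps_eval(d)
print(evaluate(prompt))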