Compute log probabilities of any sequence provided

Hi @Cbelem! Thank you for your help :slight_smile:

I believe they updated this API and it is now easier to get these scores.

Yes, I tried the new compute_transition_scores function, and its scores match those returned by generate, but the mismatch with the loss persists. Maybe @joaogante can explain the mismatch.

What transformers version are you using?

I am using v4.26.1

Minimal example (updated with compute_transition_scores):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load a T5-small model
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
tokenizer = T5Tokenizer.from_pretrained('t5-small', model_max_length=512)
model.eval()

# define some source text and tokenize it
source_text = "This is a source sentence."
source_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(device)

# generate the output using beam search
gen_outputs = model.generate(
    inputs=source_ids,
    num_beams=2,
    min_length=0,
    max_length=512,
    length_penalty=0,
    output_scores=True,
    return_dict_in_generate=True,
)

# compute the scores using compute_transition_scores()
scores = model.compute_transition_scores(
    sequences=gen_outputs.sequences,
    scores=gen_outputs.scores,
    beam_indices=gen_outputs.beam_indices,
)

# compute the loss for the generated sequence
loss = model(
    input_ids=source_ids,
    attention_mask=torch.ones_like(source_ids),
    labels=gen_outputs.sequences,
    return_dict=True
).loss.item()

# compare the scores given by generate() with the loss given by forward()
print('scores (generate):', gen_outputs.sequences_scores.item())
print('scores (compute_transition_scores):', scores.sum().item())
print('loss * seq_len:', loss * gen_outputs.sequences.shape[-1])
print('loss:', loss)

Output:

scores (generate): -3.2493550777435303
scores (compute_transition_scores): -3.2493550777435303
loss * seq_len: 13.989073991775513
loss: 1.5543415546417236
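
For reference, here is a minimal sketch of the relationship the comparison above relies on. It uses only the standard library and made-up toy logits, not the T5 model, and the helper names are hypothetical. It assumes the loss is the mean negative log-likelihood over the label tokens, which is how a mean-reduced cross-entropy behaves. Under that assumption, the summed token log-probabilities should equal minus the loss times the number of scored tokens.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def token_log_probs(step_logits, token_ids):
    """Log-probability of the chosen token at every decoding step."""
    return [log_softmax(logits)[tok] for logits, tok in zip(step_logits, token_ids)]

def mean_nll(step_logits, token_ids):
    """Mean negative log-likelihood, i.e. a mean-reduced cross-entropy."""
    log_probs = token_log_probs(step_logits, token_ids)
    return -sum(log_probs) / len(log_probs)

# Toy data (hypothetical): 3 decoding steps over a vocabulary of 3 tokens.
step_logits = [[2.0, 0.5, -1.0], [0.1, 1.2, 0.3], [-0.5, 0.0, 2.5]]
token_ids = [0, 1, 2]

seq_log_prob = sum(token_log_probs(step_logits, token_ids))  # summed log-probs, negative
loss = mean_nll(step_logits, token_ids)                      # mean NLL, positive

print('sum of log-probs:', seq_log_prob)
print('-loss * n_tokens:', -loss * len(token_ids))
```

Two things in my example may therefore be worth checking, though I have not confirmed either as the cause. First, the loss is positive while the scores are negative, so the comparison needs a sign flip. Second, gen_outputs.sequences starts with the decoder start token, so passing it unchanged as labels scores one extra token; using gen_outputs.sequences[:, 1:] as labels and multiplying by the length of that shorter sequence might bring the two numbers closer.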