Compute log probabilities of any sequence provided

Hey there!

I’m using the allenai/unifiedqa-t5-small model to obtain the log probabilities of a given sequence (which is not necessarily the one generated by the model). In particular, I’m interested in the probability distribution conditioned on the previous tokens in the sequence.

So far, I’ve been using the forward method and providing the sentence I want to obtain the logits for as the labels argument. However, I am not 100% confident the resulting logits are conditioned on the sentence I provided as labels (i.e., does the forward method work in a teacher-forcing fashion, thus guaranteeing that the log probabilities are conditioned on the labels argument rather than on the argmax of previous predictions?).

A second question is whether it is necessary to provide the pad_token_id as the first token in the labels argument.


Hey @PastelBelem8,

Would this PR: [Generation] Fix Transition probs by patrickvonplaten · Pull Request #17311 · huggingface/transformers · GitHub and this forum post: Generation Probabilities: How to compute probabilities of output scores for GPT2 help?

Hey @PastelBelem8, have you figured out how to get the log probabilities of any given sequence yet?

I’m trying to implement the same thing using the T5 model. My approach was to use .compute_transition_scores() and reconstruct the sequence score.

As written in the documentation, .compute_transition_scores() takes sequences and scores and returns the transition score for each token id in the sequences provided. The snippet below takes the output of model.generate() and gets transition scores for the generated example. However, we could pass the token ids of the wanted sequence to .compute_transition_scores() instead of outputs.sequences.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/unifiedqa-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/unifiedqa-t5-small")
inputs = tokenizer(["question: <q> context: <c>"], return_tensors="pt")

outputs = model.generate(inputs.input_ids, return_dict_in_generate=True, output_scores=True)

# instead of this
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=False
)

# do this
wanted_seq = tokenizer.batch_encode_plus(["wanted sequence"], return_tensors="pt").input_ids
# prepend the decoder start token (the pad token, id 0, for T5)
wanted_seq = torch.cat([torch.tensor([[0]]), wanted_seq], dim=1)
transition_scores = model.compute_transition_scores(
    wanted_seq, outputs.scores, normalize_logits=False
)

However, I’ve realized that when the length of wanted_seq is greater than the length of the sequence produced by .generate(), outputs.scores does not cover the full length of the wanted_seq input ids. So I’ve been trying to modify the behavior of .generate() to get scores for the full length by passing stopping_criteria, but I still haven’t figured out how to do so.

What I’d like to have is scores covering the full length I want (up to 32 in the case below, without early stopping). Can @patrickvonplaten help me with this? I’ve tried the following, but it throws an error.

from transformers.generation.stopping_criteria import StoppingCriteriaList, MaxLengthCriteria

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)])
outputs = model.generate(
    inputs["input_ids"].cuda() if torch.cuda.is_available() else inputs["input_ids"],
    num_beams=args.num_beams,
    max_length=32,
    stopping_criteria=stopping_criteria,
    early_stopping=False,
    output_scores=True,
    return_dict_in_generate=True,
)

Hey!

I did find a way to compute those scores! I think the new release of Hugging Face Transformers has significant changes in terms of computing scores for sequences (I haven’t tried computing the scores with it yet).

If you still want to use your method, I would suggest specifying the min_length argument during generate, which forces generations to be at least that many tokens long.
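
For example, something along these lines (a sketch reusing the model and inputs from your earlier snippet; 32 matches the length used above):

# Sketch: setting min_length prevents EOS from ending generation early,
# so outputs.scores covers the full generated length.
outputs = model.generate(
    inputs.input_ids,
    min_length=32,
    max_length=32,
    output_scores=True,
    return_dict_in_generate=True,
)
print(len(outputs.scores))  # one score tensor per decoding step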

From what I gathered, you’re trying to compute the log probabilities of any provided sequence, right? Over time I’ve been iterating on other ways to compute those probabilities, and I collected them in this notebook.

I hope this helps :slight_smile:


Hey @Cbelem! Thank you for sharing the notebook!

Hey @Cbelem,

Thank you for sharing this notebook.

I wonder if you have compared the sequences_scores output by the generate() method with the loss output by forward() when you provide the generated sequence as labels. Assuming forward() does teacher forcing as it should, the loss should be the negative log-probability of the sequence normalized by its length, so I was expecting loss * sequence_length and sequences_scores to match in absolute value, but it turns out they do not.

Minimal example:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# load a T5-small model
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
tokenizer = T5Tokenizer.from_pretrained('t5-small', model_max_length=512)
model.eval()

# define some source text and tokenize it
source_text = "This is a source sentence."
source_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(device)

# generate the output using beam search
gen_outputs = model.generate(
    inputs=source_ids,
    num_beams=2,
    min_length=0,
    max_length=512,
    length_penalty=0,
    output_scores=True,
    return_dict_in_generate=True,
)

# compute the loss for the generated sequence
loss = model(
    input_ids=source_ids,
    attention_mask=torch.ones_like(source_ids),
    labels=gen_outputs.sequences,
    return_dict=True
).loss.item()

# compare the scores given by generate() with the loss given by forward()
print('scores:', gen_outputs.sequences_scores.item())
print('loss * seq_len:', loss * gen_outputs.sequences.shape[-1])
print('loss:', loss)

This will output:

scores: -3.2493550777435303
loss * seq_len: 13.989073991775513
loss: 1.5543415546417236

I don’t see any reason for the mismatch. Am I missing something obvious?

Hi @dpernes !

This is a great question! In the past, I’ve seen similar threads, such as sequences-scores do not compute what is expected for beam_search, and also this one about the mismatch between the log probabilities from beam search and the computed transition scores. I believe they updated this API and it is now easier to get these scores.

I’m going to look into your examples and try to understand a bit better what’s going on. What transformers version are you using?

Hi @Cbelem! Thank you for your help :slight_smile:

I believe they updated this API and it is now easier to get these scores.

Yes, I tried with the new function compute_transition_scores and the scores match those provided by generate, but the mismatch with the loss persists. Maybe @joaogante can explain the mismatch.

What transformers version are you using?

I am using v4.26.1

Minimal example (updated with compute_transition_scores):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# load a T5-small model
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
tokenizer = T5Tokenizer.from_pretrained('t5-small', model_max_length=512)
model.eval()

# define some source text and tokenize it
source_text = "This is a source sentence."
source_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(device)

# generate the output using beam search
gen_outputs = model.generate(
    inputs=source_ids,
    num_beams=2,
    min_length=0,
    max_length=512,
    length_penalty=0,
    output_scores=True,
    return_dict_in_generate=True,
)

# compute the scores using compute_transition_scores()
scores = model.compute_transition_scores(
    sequences=gen_outputs.sequences,
    scores=gen_outputs.scores,
    beam_indices=gen_outputs.beam_indices,
)

# compute the loss for the generated sequence
loss = model(
    input_ids=source_ids,
    attention_mask=torch.ones_like(source_ids),
    labels=gen_outputs.sequences,
    return_dict=True
).loss.item()

# compare the scores given by generate() with the loss given by forward()
print('scores (generate):', gen_outputs.sequences_scores.item())
print('scores (compute_transition_scores):', scores.sum().item())
print('loss * seq_len:', loss * gen_outputs.sequences.shape[-1])
print('loss:', loss)

Output:

scores (generate): -3.2493550777435303
scores (compute_transition_scores): -3.2493550777435303
loss * seq_len: 13.989073991775513
loss: 1.5543415546417236

Before replying to the other points, I want to highlight the following:

:warning: Please do not use the method in this comment (cc @Seohyeong) to get scores for an arbitrary sequence.

Text generation is auto-regressive: it predicts the next tokens based on the tokens predicted so far. The scores field in the output contains the logits for all tokens in the vocabulary at each position. Yes, you can obtain a score for any token. But that score is only correct if the preceding tokens are exactly the same! In this colab you will see that if we change the model inputs (i.e. the source of the scores), their values will change for a few selected tokens.

:point_right: See here for an example on how to compute the token-level scores for any sequence :slight_smile:
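
For context, here is a minimal sketch of one such computation with t5-small (the source and target strings are placeholders): run a single forward pass with the wanted sequence as labels, which teacher-forces the decoder, and read the per-token log probabilities off the returned logits.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

source_ids = tokenizer("This is a source sentence.", return_tensors="pt").input_ids
target_ids = tokenizer("This is a target sentence.", return_tensors="pt").input_ids

with torch.no_grad():
    # Passing labels teacher-forces the decoder: logits[:, t] is the distribution
    # over the token at position t, conditioned on the target tokens before t.
    out = model(input_ids=source_ids, labels=target_ids)

log_probs = torch.log_softmax(out.logits, dim=-1)
token_scores = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1)
print(token_scores)        # log p(token_t | source, target_<t) for each position
print(token_scores.sum())  # log probability of the whole target sequence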


Hey @dpernes :wave: A tiny correction is needed in your example, and we get the matching numbers as expected. In a nutshell, gen_outputs.sequences contains the BOS token, which is not predicted by the model – all outputs start with it. So, when setting labels, we have to skip it :smiley:

Translated to code, replace the last lines of your script with

# compute the loss for the generated sequence
loss = model(
    input_ids=source_ids,
    attention_mask=torch.ones_like(source_ids),
    labels=gen_outputs.sequences[:, 1:],  # skip BOS token
    return_dict=True
).loss.item()

# compare the scores given by generate() with the loss given by forward()
print('scores (generate):', gen_outputs.sequences_scores.item())
print('scores (compute_transition_scores):', scores.sum().item())
print('loss * seq_len:', loss * (gen_outputs.sequences.shape[-1] - 1))  # correct length
print('loss:', loss)

The whole script now outputs:

scores (generate): -3.249357223510742
scores (compute_transition_scores): -3.2493574619293213
loss * seq_len: 3.2493574619293213
loss: 0.40616968274116516

Oh, I didn’t realize you were including the BOS token in the generated sequence. Makes a lot of sense, now, thank you! :slight_smile:

Thank you so much!

If I were to calculate the perplexity score of the generated sequence, exp(loss) based on your calculation should be correct, right?
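
For reference, perplexity is conventionally the exponential of the average per-token negative log-likelihood, which is exactly what the loss returned by forward() is, so something like this (reusing loss from the example above) should work:

import math

# loss is the mean per-token negative log-likelihood from forward()
perplexity = math.exp(loss)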

I am trying to reproduce the sentiment generation task in the DPO paper.

It states that the reward can be modeled as below, with π(y|x) being the sequence probability from a finetuned policy model.
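
For reference, the reward reparameterization in the DPO paper has the form

r(x, y) = β · log( π_θ(y|x) / π_ref(y|x) ) + β · log Z(x)

where π_θ is the fine-tuned policy, π_ref is the reference (SFT) model, and Z(x) is a partition function that cancels when comparing two responses to the same prompt x.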

How do I compute this π(y|x)? I am assuming that I can get it by calling generate on the finetuned model and computing sequence probabilities.

I have this current setup to sample 4 sequences and get the sequence probability for all 4 generations:

# loaded_model / loaded_tokenizer: a fine-tuned GPT-2 model and its tokenizer, already on `device`
prompt = "Enchanted was a movie "
input_ids = loaded_tokenizer.encode(prompt, return_tensors="pt").to(device)
output = loaded_model.generate(
    input_ids,
    max_length=100,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_return_sequences=4,
    return_dict_in_generate=True,
    output_scores=True,
    pad_token_id=50256)

# per-token transition scores (log probs, since normalize_logits=True), one row per returned sequence
logits = loaded_model.compute_transition_scores(
    output.sequences, output.scores, normalize_logits=True)

log_probs = logits.sum(dim=1)

Am I correct in assuming that logits are log-probs for each generated token for each sequence? And if I want the overall probability for each generated sequence, I can sum up the logits for each sequence?

@joaogante could we use that method if we have the same prompts (preceding tokens) and target sequence but different models? In other words, for different models, I’d like to compute log probs for a sequence given the prompt.


Hi @jaydeepb, I am trying to do the same: I want to get the log probabilities of an output sentence given a prompt; my model of interest is LLaMA-2. Have you figured out how to do it? I would appreciate it if you could share the script.

Can someone double check if what I am doing here is correct?

input_tokens = tokenizer.encode(source_text, add_special_tokens=False, return_tensors="pt").to(device)
input_tokens_updated = input_tokens.clone().to(torch.int64).to(device)
output_tokens = tokenizer.encode(tgt_text, add_special_tokens=False, return_tensors="pt")[0].to(device)
log_sum = 0

for i in range(len(output_tokens)):
    # Predict with the given model
    with torch.no_grad():
        outputs = model.generate(input_tokens_updated, max_new_tokens=1, output_logits=True, return_dict_in_generate=True)
        logit_predictions = outputs.logits[0]

    # Extract the log probability of the output token
    token = tokenizer.decode(output_tokens[i])
    log_probs = torch.nn.functional.log_softmax(logit_predictions, dim=-1)
    out_token_logit = logit_predictions[0, output_tokens[i]]
    out_token_log_prob = log_probs[0, output_tokens[i]]
    log_sum += out_token_log_prob
    print(f"Token: {token}", "logit: ", out_token_logit, "log prob: ", out_token_log_prob)

    # Incrementally add an output token to the current sequence
    input_tokens_updated = torch.cat([input_tokens_updated, output_tokens[i].reshape(1, 1)], dim=1)
    print([tokenizer.decode(token) for token in input_tokens_updated])
    print("============")
print(f"Total Log Sum Probability: {log_sum}")

At each step, I generate one token and calculate the log probability of the target token I am interested in. I assume that the final log_sum represents the log probability of the target sentence.

Hi, assuming your batch of input sequences is left-padded, i.e. each row looks like <pad>{context}{response}, where <pad> represents zero or more padding tokens, we would like to compute the probability that the model generates the response conditioned on the context, p(response | context), or its log.

I believe a nicer implementation is the following, assuming you already have two tensors:

  • input_ids, of shape (batch_size, seq_len): the token ids of your batch of context + responses, left-padded
  • context_lengths, of shape (batch_size, 1): the length of the padding tokens plus the context for each sequence

import torch
from trl.trainer.utils import forward

raw_outputs = forward(model, input_ids, pad_token_id)  # raw logits, (batch_size, seq_len, vocab_size)
token_distribution = torch.log_softmax(raw_outputs.logits, dim=-1)  # transform to log probs over the vocabulary

# The logits at position i predict the token at position i + 1, so shift by one:
# sequence_logprob[:, i] = log p(input_ids[:, i + 1] | input_ids[:, :i + 1])
sequence_logprob = torch.gather(token_distribution[:, :-1], 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len - 1)

# Now we only need to sum the part from the actual response, using a mask built from the context lengths.
# Position i of sequence_logprob scores the token at input_ids[:, i + 1], so the response tokens are those with i + 1 >= context_lengths.
indices = torch.arange(sequence_logprob.shape[1], device=sequence_logprob.device).repeat(sequence_logprob.shape[0], 1)
mask = (indices + 1) >= context_lengths
response_logprob = (sequence_logprob * mask).sum(1).unsqueeze(1)  # (batch_size, 1)

If your batch has a constant context length across all rows, then the masking can be avoided by directly slicing, which saves some computation, for example:
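
# Sketch, with ctx_len a hypothetical shared context length (padding + context) across the batch.
# sequence_logprob[:, i] scores the token at input_ids[:, i + 1], so the response starts at index ctx_len - 1 after the shift.
response_logprob = sequence_logprob[:, ctx_len - 1:].sum(1).unsqueeze(1)  # (batch_size, 1)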

Also, the helper function forward is defined as follows; it requires all sequences to be left-padded. However, I am not sure that the torch.masked_fill is necessary here; in my opinion this step is redundant.

def forward(
    model: torch.nn.Module,
    query_responses: torch.Tensor,
    pad_token_id: int,
) -> torch.nn.Module:
    """
    Performs a forward pass through the model with the given query responses and pad token ID.

    Args:
        model (`torch.nn.Module`):
            The model to perform the forward pass.
        query_responses (`torch.Tensor`):
            The tensor containing the query responses.
        pad_token_id (`int`):
            The token ID representing the pad token.

    Returns:
        `torch.nn.Module`:
            The output of the model, including hidden states.
    """
    attention_mask = query_responses != pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )