Generation Probabilities: How to compute probabilities of output scores for GPT2

Now that it is possible to return the logits generated at each step, one might wonder how to compute the probabilities for each generated sequence accordingly.

The following code snippet showcases how to do so for generation with do_sample=True for GPT2:

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer


gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Today is a nice day", return_tensors="pt").input_ids

# return_dict_in_generate=True makes generate() return an output object
# with .sequences and .scores instead of a plain tensor of token ids
generated_outputs = gpt2.generate(
    input_ids,
    do_sample=True,
    num_return_sequences=3,
    output_scores=True,
    return_dict_in_generate=True,
)

# only use the ids that were generated (drop the prompt tokens)
# gen_sequences has shape [3, 15]
gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]

# let's stack the logits generated at each step to a tensor and transform
# logits to probs
probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)  # -> shape [3, 15, vocab_size]

# now we need to collect the probability of the generated token
# we need to add a dummy dim in the end to make gather work
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)

# now we can do all kinds of things with the probs

# 1) the probs that exactly those sequences are generated again
# those are normally going to be very small
unique_prob_per_sequence = gen_probs.prod(-1)

# 2) normalize the probs over the three sequences
normed_gen_probs = gen_probs / gen_probs.sum(0)
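# compare with a tolerance: exact float equality can fail due to rounding
assert torch.allclose(normed_gen_probs[:, 0].sum(), torch.tensor(1.0)), "probs should be normalized"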

# 3) compare normalized probs to each other like in 1)
unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
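
Multiplying many small per-token probabilities can underflow for longer generations, so it is often safer to work in log space. The following is a minimal sketch of that idea, not part of the original snippet: `sequence_log_probs` is a hypothetical helper, and the random tensors are synthetic stand-ins for `generated_outputs.scores` and `gen_sequences` so that it runs without downloading a model.

```python
import torch


def sequence_log_probs(scores, gen_sequences):
    """Sum of per-token log-probabilities for each generated sequence.

    scores:        tuple of [batch, vocab_size] logit tensors, one per step
    gen_sequences: [batch, steps] tensor with the generated token ids
    """
    # stack the per-step logits and normalize them in log space
    log_probs = torch.stack(scores, dim=1).log_softmax(-1)  # [batch, steps, vocab_size]
    # pick the log-probability of the token that was actually generated
    token_log_probs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    # a sum of logs replaces the product of probabilities
    return token_log_probs.sum(-1)


# synthetic stand-ins (assumption: same layout as the real generate() outputs)
torch.manual_seed(0)
batch, steps, vocab_size = 3, 15, 50
scores = tuple(torch.randn(batch, vocab_size) for _ in range(steps))
gen_sequences = torch.randint(0, vocab_size, (batch, steps))

seq_log_probs = sequence_log_probs(scores, gen_sequences)  # shape [3]
```

Calling `seq_log_probs.exp()` should recover the same values as `gen_probs.prod(-1)` whenever the product is large enough to be representable.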

Can I use this to generate sequences only over a probability threshold?

Great to see this much-needed feature.

I want to try it out, but with transformers 4.2.0 I get an error like "TypeError: forward() got an unexpected keyword argument 'return_dict_in_generate'".

It has to be used with generate() - not with forward() :)

No, I don't think so, sadly. Such a feature would be very hard to implement.
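
Since generation itself cannot be constrained this way, one possible workaround is to sample more sequences than needed and filter them afterwards by their score. This is only a sketch under stated assumptions: `filter_by_log_prob` is a hypothetical helper, not a transformers API, and the token ids and log-probabilities below are made-up values standing in for real `generate()` outputs.

```python
import torch


def filter_by_log_prob(sequences, seq_log_probs, min_log_prob):
    """Keep only the sequences whose total log-probability reaches the threshold."""
    keep = seq_log_probs >= min_log_prob  # boolean mask over the batch
    return sequences[keep], seq_log_probs[keep]


# made-up example data: three "generated" sequences and their log-probabilities
sequences = torch.tensor([[11, 12, 13], [21, 22, 23], [31, 32, 33]])
seq_log_probs = torch.tensor([-4.0, -12.5, -7.0])

kept_sequences, kept_log_probs = filter_by_log_prob(sequences, seq_log_probs, min_log_prob=-8.0)
```

Note that this does not guarantee any sequence survives the filter; if none do, one would have to sample again or lower the threshold.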