Hi,
I was trying to reimplement a simple one-beam (greedy) search for T5, based on the awesome work of Thomas Wolf, to better understand how HuggingFace generates new tokens, and I am a bit bewildered by a discovery: results improve a lot when I pad the input with pad_token_ids.
Here is a minimal reproducible example:
import torch
import torch.nn.functional as F
from transformers import T5Tokenizer, T5ForConditionalGeneration
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Function created by Thomas Wolf of the huggingface team.
    Args:
        logits: logits distribution shape (vocabulary size)
        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
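(As an aside, this helper is not actually used in the greedy loops below; I kept it from the gist. For context, a typical sampling use would look roughly like this - a minimal sketch, with arbitrary top_k/top_p values and random logits standing in for real model output:)

# Illustrative only: plugging the filter into a sampling step
# (top_k/top_p values are arbitrary examples, not recommendations)
example_logits = torch.randn(32128)  # random 1-D logits, sized like the t5-base vocabulary
filtered_logits = top_k_top_p_filtering(example_logits, top_k=50, top_p=0.95)
probs = F.softmax(filtered_logits, dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1)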
pretrained_model = 't5-base'
# Loading the tokenizer and model
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
t5_conditional = T5ForConditionalGeneration.from_pretrained(pretrained_model)
# Split out the components so each step of generation can be run manually
encoder, decoder, lm_head = t5_conditional.encoder, t5_conditional.decoder, t5_conditional.lm_head
########### Without Padding ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
# Encode the source sentence once
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
    # Feed everything generated so far to the decoder and greedily pick the next token
    decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
    logits = lm_head(decoder_output)
    next_token_logits = logits[0, -1, :]
    next_token = torch.argmax(next_token_logits).unsqueeze(0)
    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
tokenizer.decode(generated[0])
# Output generated: ................. (in tokens, a series of 5s)
########### With One Padded "0" ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
# Append a single pad token (id 0) to the input
generated = torch.cat((generated, torch.tensor([[0]])), dim=1)
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
    decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
    logits = lm_head(decoder_output)
    next_token_logits = logits[0, -1, :]
    next_token = torch.argmax(next_token_logits).unsqueeze(0)
    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
tokenizer.decode(generated[0])
# Output generated: à l'occasion de l'accident, j'
########### With Ten Padded "0s" ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
# Append ten pad tokens (id 0) to the input
generated = torch.cat((generated, torch.tensor([[0] * 10])), dim=1)
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
    decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
    logits = lm_head(decoder_output)
    next_token_logits = logits[0, -1, :]
    next_token = torch.argmax(next_token_logits).unsqueeze(0)
    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
tokenizer.decode(generated[0])
# Output generated: J'ai été victime d'une série d'accidents
The same behaviour occurs whether or not the prompt ends with an eos token.
Would anyone know why padding improves the results so much, and what the optimal amount of padding is?
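For reference, this is the built-in call I am comparing my loop against (a minimal sketch; as far as I understand, generate with default arguments performs greedy decoding and handles the decoder start token internally, and max_length=32 is an arbitrary choice):

input_ids = tokenizer('translate English to French: I was a victim of a series of accidents.',
                      return_tensors='pt')['input_ids']
output_ids = t5_conditional.generate(input_ids, max_length=32)  # greedy decoding by default
print(tokenizer.decode(output_ids[0]))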
Many thanks,
Abel