Padding with pad_token_id improves results for T5?

Hi,
I was trying to reimplement a simple single-beam (greedy) search for T5, based on the awesome work of Thomas Wolf, to better understand how HuggingFace generates new tokens, and I am a bit bewildered by a discovery: results appear to improve a lot when I pad the text with pad_token_ids.

Here is a minimal reproducible example:

import torch
import torch.nn.functional as F

from transformers import T5Tokenizer, T5ForConditionalGeneration

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Function created by Thomas Wolf of the huggingface team
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert (
        logits.dim() == 1
    )  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
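
For what it's worth, here is a rough sketch (not part of the original gist) of how I imagine the filter plugging into the decoding loop in place of the plain argmax used below; the temperature, top_k and top_p values are just illustrative:

def sample_next_token(next_token_logits, temperature=1.0, top_k=50, top_p=0.95):
    # Scale the logits, filter them, then sample one token id from what remains
    filtered_logits = top_k_top_p_filtering(next_token_logits / temperature, top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)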

pretrained_model = 't5-base'

# Loading models
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
t5_conditional = T5ForConditionalGeneration.from_pretrained(pretrained_model)
encoder, decoder, lm_head = t5_conditional.encoder, t5_conditional.decoder, t5_conditional.lm_head

########### Without Padding ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
  decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
  logits = lm_head(decoder_output)
  next_token_logits = logits[0, -1, :]
  next_token = torch.argmax(next_token_logits).unsqueeze(0)
  generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

tokenizer.decode(generated[0])
# Output generated: ................. (in tokens, a series of 5s)

########### With One Padded "0" ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
generated = torch.cat((generated, torch.tensor([[0]])), dim=1)
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
  decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
  logits = lm_head(decoder_output)
  next_token_logits = logits[0, -1, :]
  next_token = torch.argmax(next_token_logits).unsqueeze(0)
  generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

tokenizer.decode(generated[0])
# Output generated:   ⁇  l'occasion de l'accident, j'

########### With Ten Padded "0s" ############
generated = torch.tensor(
    [tokenizer('translate English to French: I was a victim of a series of accidents.')['input_ids']])
generated = torch.cat((generated, torch.tensor([[0] * 10])), dim=1)
encoded_embeddings = encoder(generated)[0]
for _ in range(16):
  decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
  logits = lm_head(decoder_output)
  next_token_logits = logits[0, -1, :]
  next_token = torch.argmax(next_token_logits).unsqueeze(0)
  generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

tokenizer.decode(generated[0])
# Output generated: J'ai été victime d'une série d'accidents

The same applies to prompts with or without EOS tokens.
Would anyone know why padding improves the results so much, and what the optimal amount of padding is?

Many thanks,

Abel

Hi, in the first example the source text ids are also passed to the decoder, which they should not be. When generating, the decoder sequence should start with the decoder_start_token_id, not with the source ids. So for the first step, pass the encoder_hidden_states together with the decoder_start_token_id as the first decoder input id. The correct usage would be:

enc = tokenizer(['translate English to French: I was a victim of a series of accidents.'], return_tensors="pt")
input_ids = enc['input_ids']
encoded_embeddings = encoder(input_ids)[0]

# the decoder input should start with the decoder_start_token_id;
# for T5, pad_token_id serves as the decoder_start_token_id
generated = torch.tensor([[tokenizer.pad_token_id]])

for _ in range(16):
  decoder_output = decoder(input_ids=generated, encoder_hidden_states=encoded_embeddings)[0]
  logits = lm_head(decoder_output)
  next_token_logits = logits[0, -1, :]
  next_token = torch.argmax(next_token_logits).unsqueeze(0)
  generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

tokenizer.decode(generated[0])

which generates J'ai été victime d'une série d'accident.
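
For comparison, the built-in generate method does the same bookkeeping internally (it starts the decoder from pad_token_id and feeds the generated ids back in), so you can sanity-check the manual loop against it, roughly like this:

# Greedy generation with the high-level API; no manual decoder seeding needed
output_ids = t5_conditional.generate(input_ids, max_length=16)
print(tokenizer.decode(output_ids[0]))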

Also, when passing padded text, we should pass the attention_mask as well so the pad tokens won't be attended to. The tokenizer returns the attention_mask along with the input_ids.
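
For example, reusing the names from above, something like this (a rough sketch) should work:

enc = tokenizer(['translate English to French: I was a victim of a series of accidents.'],
                return_tensors="pt", padding=True)
# Forward the mask to the encoder so padded positions are ignored
encoded_embeddings = encoder(enc['input_ids'], attention_mask=enc['attention_mask'])[0]
# And to the decoder's cross-attention via encoder_attention_mask
decoder_output = decoder(input_ids=generated,
                         encoder_hidden_states=encoded_embeddings,
                         encoder_attention_mask=enc['attention_mask'])[0]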

And to create tensors from the tokenized ids, pass the return_tensors argument to the tokenizer, which will then return tensors instead of lists: 'pt' for PyTorch tensors, 'tf' for TensorFlow tensors. If you pass a list of strings, the tokenizer will automatically batch them.
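
For example:

# A list of strings is batched automatically; padding=True pads to the longest example
batch = tokenizer(['translate English to French: Hello.',
                   'translate English to French: How are you doing today?'],
                  return_tensors="pt", padding=True)
print(batch['input_ids'].shape)       # (2, longest sequence length)
print(batch['attention_mask'].shape)  # same shape, with 0s over the padded positions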

Hi, thank you so much for the thoughtful answer; you answered every question I had and more! I made the change so that only the generated input_ids are passed to the decoder, and it's purring!