I’m currently trying to train a T5ForConditionalGeneration model for a seq2seq task, and I was wondering whether we can expect T5 to internally ignore padding tokens in decoder_input_ids (e.g. by generating the attention mask itself) if we don’t explicitly provide decoder_attention_mask.
I noticed from the code that T5 simply creates an attention mask of all 1s when decoder_attention_mask is not provided, so it seems we end up attending to the padding tokens. I also ran a sanity check to see whether providing decoder_attention_mask makes any meaningful difference in the logits, and it does.
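For reference, this is roughly the fallback I’m referring to in T5Stack.forward (paraphrased from modeling_t5.py; exact variable names may differ across transformers versions):

# Paraphrased from T5Stack.forward in modeling_t5.py (names approximate).
# If no mask is passed in, every position, padding included, gets attended to.
if attention_mask is None:
    attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device)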
So I’m wondering whether this is by design, since it doesn’t seem to make sense to attend to padding tokens in batched forward passes.
Below is the sanity check that I ran (I know decoder_input_ids is normally supposed to differ from input_ids, but I figured that’s not important for this particular issue).
import torch
import transformers
model = transformers.T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = transformers.T5Tokenizer.from_pretrained("t5-base")
model.cuda()
model.eval()
texts = ["This is a test input.", "This is a test input to test T5 padding scheme."]
input_ids = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
input_ids = input_ids.to("cuda")  # move all tensors in the BatchEncoding to the GPU
with torch.inference_mode():
    # Shift decoder input ids to the right
    decoder_input_ids = model._shift_right(input_ids.input_ids)
    # Manually give the correct decoder attention mask: a 1 for the shifted-in
    # decoder start token, then the encoder attention mask shifted right by one
    with_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=torch.cat((
            torch.tensor([[1], [1]], device="cuda"),
            input_ids.attention_mask[:, :-1]), dim=1
        )
    ).logits
    # Explicitly give a decoder attention mask of all 1s
    all_1_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=torch.ones(decoder_input_ids.shape, device="cuda"),
    ).logits
    # Don't give a decoder attention mask at all.
    no_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
    ).logits
    print(torch.all(torch.isclose(with_attn_mask_logits, no_attn_mask_logits)).item())  # False
    print(torch.equal(with_attn_mask_logits, no_attn_mask_logits))  # False
    print(torch.all(torch.isclose(all_1_attn_mask_logits, no_attn_mask_logits)).item())  # True
    print(torch.equal(all_1_attn_mask_logits, no_attn_mask_logits))  # True
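For completeness, here is a minimal sketch of how I imagine the mask could be built directly from the padded decoder inputs instead of via torch.cat (assumption: padding positions in decoder_input_ids hold pad_token_id, and position 0 is the shifted-in decoder start token, which T5 also encodes as pad_token_id, so it has to be set back to visible):

# Minimal sketch: derive the decoder attention mask from decoder_input_ids itself.
# Assumption: padding positions hold pad_token_id; position 0 is the shifted-in
# decoder start token (also pad_token_id for T5), so it must stay unmasked.
decoder_attention_mask = (decoder_input_ids != model.config.pad_token_id).long()
decoder_attention_mask[:, 0] = 1  # keep the decoder start token visible

For this example it produces the same mask as the torch.cat construction above, but it also works when decoder_input_ids doesn't mirror input_ids.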