I’m currently trying to train a T5ForConditionalGeneration model for a seq2seq task, and I was wondering whether we can expect T5 to internally ignore padding tokens in decoder_input_ids (e.g. by generating an attention mask) if we don’t explicitly provide decoder_attention_mask?
I noticed from the code that T5 simply creates an attention mask of all 1s if decoder_attention_mask is not provided, so it seems we are attending to padding tokens. I also ran a sanity check to see whether providing decoder_attention_mask makes any meaningful difference to the logits, and it does.
So I’m wondering whether this is by design, because it doesn’t seem to make sense to attend to padding tokens in batched forward passes.
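For reference, the fallback I’m referring to in the decoder stack looks roughly like this (paraphrased from T5Stack.forward in modeling_t5.py; the exact lines may differ across transformers versions):

# Paraphrased from T5Stack.forward: if no mask is passed in, a mask of all 1s
# is created, so every position (including padding) is attended to.
if attention_mask is None:
    attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)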
Below is the sanity check I ran (I know decoder_input_ids is normally supposed to be different from input_ids, but I figured that’s not important for this particular issue).
import torch
import transformers

model = transformers.T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = transformers.T5Tokenizer.from_pretrained("t5-base")
model.cuda()
model.eval()

texts = ["This is a test input.", "This is a test input to test T5 padding scheme."]
input_ids = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
input_ids = input_ids.to("cuda")

with torch.inference_mode():
    # Shift decoder input ids to the right
    decoder_input_ids = model._shift_right(input_ids.input_ids)

    # Manually give the correct attention mask
    with_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=torch.cat(
            (
                torch.tensor([[1], [1]], device="cuda"),
                input_ids.attention_mask[:, :-1],
            ),
            dim=1,
        ),
    ).logits

    # Give an attention mask of all 1s explicitly
    all_1_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=torch.ones(decoder_input_ids.shape, device="cuda"),
    ).logits

    # Don't give a decoder attention mask at all
    no_attn_mask_logits = model(
        input_ids=input_ids.input_ids,
        attention_mask=input_ids.attention_mask,
        decoder_input_ids=decoder_input_ids,
    ).logits

print(torch.all(torch.isclose(with_attn_mask_logits, no_attn_mask_logits)).item())  # False
print(torch.equal(with_attn_mask_logits, no_attn_mask_logits))  # False
print(torch.all(torch.isclose(all_1_attn_mask_logits, no_attn_mask_logits)).item())  # True
print(torch.equal(all_1_attn_mask_logits, no_attn_mask_logits))  # True
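And for completeness, here is a minimal sketch (reusing the variables from the snippet above) of how I’d build decoder_attention_mask directly from the labels instead of reusing the encoder mask. The subtle part is that T5 uses the pad token as the decoder start token, so naively computing (decoder_input_ids != tokenizer.pad_token_id) would also mask out the first decoder position; shifting the label mask the same way the labels are shifted avoids that.

# Sketch: derive the decoder attention mask from the labels rather than from
# decoder_input_ids. After model._shift_right, position 0 of decoder_input_ids
# holds the pad token (T5's decoder start token), so comparing decoder_input_ids
# against pad_token_id directly would wrongly mask that position.
labels = input_ids.input_ids  # stand-in for real target ids
decoder_input_ids = model._shift_right(labels)
label_mask = (labels != tokenizer.pad_token_id).long()
decoder_attention_mask = torch.cat(
    (
        torch.ones((labels.shape[0], 1), dtype=torch.long, device=labels.device),
        label_mask[:, :-1],  # shifted right, mirroring _shift_right
    ),
    dim=1,
)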