Decoder attention mask in text2text/seq2seq generation encoder-decoder models

Hi guys!

Suppose I have a batch of just two sentences (for simplicity) with different lengths (let len(sent_1) > len(sent_2)). For training I have to provide the labels parameter:

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(some_list_of_two_sents, padding=True)

Suppose I get:

    labels["input_ids"] = [[x,   x,   x,   x,   x,     x,   eos],  # sent_1
                           [x,   x,   x,   x,   eos,   pad, pad]]  # sent_2
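To make the masks concrete, here is a minimal pure-Python sketch of what right padding with `padding=True` produces for such a batch. The token ids, `EOS_ID` and `PAD_ID` are hypothetical placeholders, not the values of any particular tokenizer:

```python
# Hypothetical ids standing in for the "x", "eos" and "pad" tokens above.
EOS_ID = 2
PAD_ID = 0

def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad token-id lists to the longest one and build a 1/0 attention mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token (eos included), 0 marks padding
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

sent_1 = [11, 12, 13, 14, 15, 16, EOS_ID]  # six content tokens + eos
sent_2 = [21, 22, 23, 24, EOS_ID]          # four content tokens + eos
labels = pad_batch([sent_1, sent_2])
print(labels["attention_mask"])
# [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0]]
```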

When training conditional generation models (e.g. T5ForConditionalGeneration or BartForConditionalGeneration), if the user omits the decoder_input_ids parameter, it is created automatically by shifting the labels to the right:

    decoder_input_ids = shift_tokens_right(labels["input_ids"], self.config.pad_token_id,
                                           self.config.decoder_start_token_id)

So for our simple example:

    decoder_input_ids = [[bos,  x,   x,   x,   x,   x,   x ],   # sent_1
                         [bos,  x,   x,   x,   x,  eos, pad]]   # sent_2
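The shift can be reproduced on plain Python lists. This sketch mirrors the behaviour of BART's `shift_tokens_right` (prepend the start token, drop the last position, map any `-100` to the pad id). It is an illustration, not the library code, and the ids (1 for the start token, 2 for eos, 0 for pad) are hypothetical:

```python
# Hypothetical ids for illustration only.
PAD_ID = 0
DECODER_START_ID = 1  # plays the role of "bos" in the example

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """Plain-list sketch of the right shift used to build decoder_input_ids."""
    shifted = []
    for row in input_ids:
        # prepend the start token and drop the last position so the length is unchanged
        new_row = [decoder_start_token_id] + row[:-1]
        # labels often use -100 for ignored positions; turn those into real pad ids
        new_row = [pad_token_id if tok == -100 else tok for tok in new_row]
        shifted.append(new_row)
    return shifted

labels_ids = [[11, 12, 13, 14, 15, 16, 2],   # sent_1 (2 = eos)
              [21, 22, 23, 24, 2, 0, 0]]     # sent_2 (0 = pad)
decoder_input_ids = shift_tokens_right(labels_ids, PAD_ID, DECODER_START_ID)
print(decoder_input_ids)
# [[1, 11, 12, 13, 14, 15, 16], [1, 21, 22, 23, 24, 2, 0]]
```

The output has the same shape as the example above: sent_1 loses its eos, while sent_2 keeps its eos and loses one pad.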

The question is: how should I compute/deduce the decoder_attention_mask parameter?
So far I see two options here:

  1. decoder_attention_mask = labels["attention_mask"]
  2. decoder_attention_mask = some_manipulations_on(labels["attention_mask"]), maybe shifting it right as well?
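Both candidates can be computed side by side for the example batch. Option 2 here is only one hypothetical "manipulation" (shift the mask right exactly like the ids, with a 1 for the start token); it is a sketch for comparison, not a recommended recipe:

```python
# labels["attention_mask"] for the example batch
labels_mask = [[1, 1, 1, 1, 1, 1, 1],
               [1, 1, 1, 1, 1, 0, 0]]

def option_1(mask):
    # use the labels mask unchanged
    return [row[:] for row in mask]

def option_2_shifted(mask):
    # hypothetical variant: start token always attended, mask moved one step right
    return [[1] + row[:-1] for row in mask]

print(option_1(labels_mask))
# [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0]]
print(option_2_shifted(labels_mask))
# [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0]]
```

The two candidates coincide for the longest sentence and differ only at position 5 of sent_2, which is exactly where decoder_input_ids holds the eos.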

In my opinion 1) is correct, since the difference between labels and decoder_input_ids is the following:

    labels            = bla_bla + eos + pads
    decoder_input_ids = bos + bla_bla + pads

and one can see that labels["attention_mask"] ignores the eos in the shifted labels and takes the inserted bos into account. But I’m not sure, which is why I’m asking.
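Lining the sequences up position by position for sent_2 makes this claim checkable. The ids are the same hypothetical ones as before (1 = start token, 2 = eos, 0 = pad). In this example the only non-pad decoder input that option 1 masks is the eos at position 5, whose target is padding, while eos still appears as a target at position 4 with mask 1:

```python
labels_ids        = [21, 22, 23, 24, 2, 0, 0]   # sent_2 labels
decoder_input_ids = [1, 21, 22, 23, 24, 2, 0]   # labels shifted right
option_1_mask     = [1, 1, 1, 1, 1, 0, 0]       # labels["attention_mask"]

def align(inputs, targets, mask):
    """Pair each decoder input token with the label it is trained to predict."""
    return list(zip(inputs, targets, mask))

rows = align(decoder_input_ids, labels_ids, option_1_mask)
for pos, (inp, tgt, m) in enumerate(rows):
    print(f"pos {pos}: input={inp:>2} target={tgt:>2} mask={m}")
# pos 0: input= 1 target=21 mask=1
# pos 1: input=21 target=22 mask=1
# pos 2: input=22 target=23 mask=1
# pos 3: input=23 target=24 mask=1
# pos 4: input=24 target= 2 mask=1
# pos 5: input= 2 target= 0 mask=0
# pos 6: input= 0 target= 0 mask=0
```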

Thanks!

I am currently using VisionEncoderDecoderModel for seq2seq generation. I had chosen 1) without much thought until exactly the same question occurred to me. I’m now doubting that approach, since I’m not sure that ignoring the eos in the shifted labels is desirable. Wouldn’t we want the model to attend to the eos so that it learns when to stop?