How do Sequence to Sequence architectures (BART, LED) learn the end of generation?

Hey, I had a quick question about the internals of sequence-to-sequence training for transformers. The context is sequence-to-sequence translation tasks where the input and output lengths may differ, so the output length is not known in advance.

For some encoder-decoder models, such as BART and LED, the labels are shifted to the right during training using the source code shown below:

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # shift every token one position to the right and drop the last one
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id  # prepend the decoder start token
    return shifted_input_ids
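
For example, running the function on a toy batch of made-up ids (just for illustration, assuming the start token <s> has id 0 and the pad token id 1):

labels = torch.tensor([[14, 15, 16, 17]])  # (y_1, y_2, y_3, y_4)
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
print(decoder_input_ids)  # tensor([[ 0, 14, 15, 16]]), i.e. (<s>, y_1, y_2, y_3)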

So given the label sequence ids (y_1, y_2, \ldots, y_n), the function shifts the sequence one position to the right, removes the last token, and prepends the start token, giving (<s>, y_1, y_2, \ldots, y_{n-1}).

I can see how this function is convenient for standard autoregressive teacher-forcing training. That is, given the encoder outputs x_{1:L} and the decoder input (<s>), the model should predict y_1. Then, with decoder input (<s>, y_1) (and again with all the encoder outputs), the model should predict y_2, and so on. At the final step, the model is given the input (<s>, y_1, y_2, \ldots, y_{n-1}) and is trained to predict y_n.
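
To make sure I understand the mechanics, here is a minimal sketch of a single training step, assuming the Hugging Face transformers API and the facebook/bart-base checkpoint (these are just my assumptions for illustration; note that BART's tokenizer also wraps the target in its own <s>/</s>, but the shift works the same way):

import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

src = tokenizer(["Some source text to condition on."], return_tensors="pt")
labels = tokenizer(["A target sentence."], return_tensors="pt").input_ids

# Teacher forcing: the decoder is fed the shifted labels, while the loss targets
# the unshifted labels, so position t is trained to predict labels[:, t]
# given (<s>, y_1, ..., y_{t-1}) plus the encoder outputs.
decoder_input_ids = shift_tokens_right(
    labels, model.config.pad_token_id, model.config.decoder_start_token_id
)
logits = model(**src, decoder_input_ids=decoder_input_ids).logits  # (batch, length, vocab)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

(Passing labels=labels to the model directly does the same shift internally and returns the loss.)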

What I can’t find out, though, is how the model learns to end a sequence. If the sequence y_{1:n} has no special tokens, then the model never learns that, given (<s>, y_1, y_2, \ldots, y_n), the sequence ends and the final output should be </s>. So how does the model learn to end generation?

Hello :slightly_smiling_face:

I can give my 2 cents on this question:

The original sequence is (y_1, \ldots, y_n, \texttt{</s>}), and indeed we are shifting this input to the right when feeding it to the decoder, but we need to remember three things:

  1. Shifting to the right only gets rid of the </s> token.
  2. </s> is part of the vocabulary, so it has an entry in the decoder’s softmax output.
  3. When we feed the decoder the last token of the shifted sequence, y_n, the training target at that position is </s>; since this is a supervised task, the loss penalizes the model whenever it predicts anything else there (see the small illustration after this list).
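
To make the alignment concrete, here is a toy illustration in plain Python (made-up string tokens rather than real ids):

labels = ["y1", "y2", "y3", "</s>"]      # original target, ends with </s>
decoder_inputs = ["<s>"] + labels[:-1]   # shifted right: </s> dropped, <s> prepended

for t, target in enumerate(labels):
    print(f"step {t}: decoder sees {decoder_inputs[: t + 1]} -> target {target}")
# the last line printed is:
# step 3: decoder sees ['<s>', 'y1', 'y2', 'y3'] -> target </s>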

With the above in mind, the model will eventually learn when to end sequences. Naturally, this isn’t perfect, which is why I think beam search also needs a maximum length parameter.
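
For instance, with the Hugging Face generate API (a rough sketch; the checkpoint and argument values are just placeholders), decoding stops either when a beam emits </s> or when max_length is reached:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

inputs = tokenizer(["Some source text."], return_tensors="pt")
generated = model.generate(
    inputs["input_ids"],
    num_beams=4,
    max_length=60,  # safety cap in case </s> is never produced
)                   # beams also stop as soon as they emit the eos token (</s>)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))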

Hi beneyal, thanks for your response!

So I take it, then, that when formatting the labels you include the eos token at the end, but not the sos id at the start of the input?

It does feel a bit strange, though, that the model doesn’t keep any explicit notion of the eos token, since, as you said, during decoding the generated sequences can run past the final token and skew the probabilities.

Anyway, thanks again for your response, I’ll stick with this!