Hey, I had a quick question about the internals of sequence-to-sequence training for transformers. The context is sequence-to-sequence translation tasks where the input and output lengths may vary, so the output length is not known in advance.
For some encoder-decoder models, such as BART and LED, the labels are shifted right during training using the source code shown below:
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # Shift the label ids one position to the right and prepend the start token
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
So given the label ids (y_1, y_2 … y_n), the function shifts the sequence one position to the right, drops the last token, and prepends the start token, giving (<s>, y_1, y_2 … y_{n-1}).
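For example, here is what that looks like on a toy batch, reusing the function above (the ids are illustrative, with 0 = <s> and 1 = <pad> as in BART's vocabulary):

labels = torch.tensor([[101, 102, 103]])  # stands in for (y_1, y_2, y_3)
shifted = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
print(shifted)  # tensor([[  0, 101, 102]]) -> (<s>, y_1, y_2)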
I can see how this function is convenient for standard autoregressive teacher-forcing training. That is, given the encoder outputs x_{1:L} and the decoder input (<s>), the model should predict y_1. Then with decoder input (<s>, y_1) (and again with all the encoder outputs) the model should predict y_2, and so on. At the final step, the model is given input (<s>, y_1, y_2 … y_{n-1}) and is trained to predict y_n.
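As I understand it, all of these steps happen in one parallel forward pass thanks to the causal mask, so the training loss is just cross-entropy between the decoder logits and the unshifted labels. A minimal sketch of that loss computation (the shapes and names are placeholders, not BART's actual internals):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 50265          # illustrative sizes
logits = torch.randn(batch, seq_len, vocab)  # stand-in for the decoder output
labels = torch.randint(0, vocab, (batch, seq_len))  # (y_1, ..., y_n)

# Position t sees (<s>, y_1, ..., y_t) through the causal mask and must
# predict y_{t+1}, which is exactly labels[:, t]
loss = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))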
What I can't figure out, though, is how the model learns to end a sequence. If the sequence y_{1:n} contains no special tokens, then the model never learns that the sequence ends after (<s>, y_1, y_2 … y_n), i.e. that the final output should be </s>. So how does the model learn to end generation?
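To make the concern concrete, here is a minimal sketch of the scenario I mean (the ids are illustrative, with 2 = </s> as in BART): if the labels really contained no </s>, then </s> would never appear among the per-position targets, so the cross-entropy loss would never reward emitting it.

labels = torch.tensor([[101, 102, 103]])  # hypothetical labels with no </s>
decoder_input = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
# decoder inputs = (<s>, y_1, y_2); per-position targets = (y_1, y_2, y_3)
print(bool((labels == 2).any()))  # False -> no supervision signal for ending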