T5 is trained with the regular cross-entropy loss, like any language model.
Suppose that you are fine-tuning T5 for translation, and you have the following training example:
* source sentence: "hello how are you"
* target sentence: "salut comment ça-va"
First, one needs to tokenize the sentences for the model using T5Tokenizer. Assuming that every word is tokenized into a single token, and that we also add T5’s special token (namely </s>, which indicates the end of a sequence), we provide the following inputs to the model:
* input tokens = [hello, how, are, you, </s>]
* label tokens = [salut, comment, ça, -, va, </s>]
Of course, we don’t provide these tokens as text to the model, but rather as integer IDs, which refer to row indices in an embedding matrix, so the actual inputs will look like:
* input_ids = [21820, 149, 33, 25, 1]
* labels = [20239, 1670, 3664, 18, 900, 1]
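As a toy sketch of that mapping (the lookup table below is illustrative and just hard-codes the T5 ids shown above; in practice T5Tokenizer produces them from its SentencePiece vocabulary):

```python
# Toy stand-in for T5Tokenizer: a lookup table from token to vocabulary id.
# The ids are the T5 ids shown above; the dict itself is for illustration only.
vocab = {
    "hello": 21820, "how": 149, "are": 33, "you": 25,
    "salut": 20239, "comment": 1670, "ça": 3664, "-": 18, "va": 900,
    "</s>": 1,
}

def encode(tokens):
    # Convert a list of tokens to ids, appending the end-of-sequence token.
    return [vocab[t] for t in tokens] + [vocab["</s>"]]

input_ids = encode(["hello", "how", "are", "you"])
labels = encode(["salut", "comment", "ça", "-", "va"])
print(input_ids)  # [21820, 149, 33, 25, 1]
print(labels)     # [20239, 1670, 3664, 18, 900, 1]
```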
In that case, you first provide the input_ids to T5’s encoder, which will turn them into a tensor of shape (batch_size, seq_len, hidden_size). Next, T5’s decoder will predict, for each position of the target sequence, the correct next token. This happens as follows:
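In tensor terms, the shapes look as follows (the sizes here are illustrative; t5-small uses hidden_size 512 and a vocabulary of 32,128 tokens):

```python
import torch

# Illustrative sizes for the running example (batch of 1, 5 source tokens,
# 6 target tokens), with t5-small's hidden size and T5's vocabulary size.
batch_size, src_len, tgt_len = 1, 5, 6
hidden_size, vocab_size = 512, 32128

# The encoder maps input_ids to hidden states of this shape:
encoder_hidden_states = torch.zeros(batch_size, src_len, hidden_size)

# The decoder then produces one next-token distribution per target position:
lm_logits = torch.zeros(batch_size, tgt_len, vocab_size)
```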
salut comment ça - va </s> => label tokens
20239 1670 3664 18 900 1 => labels
----------------------------------------------------------------------------------------------
DECODER
----------------------------------------------------------------------------------------------
0 20239 1670 3664 18 900 => decoder_input_ids
decoder_start_token salut comment ça - va => decoder input tokens
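The shift shown above can be sketched in plain Python (this mirrors what the model does internally when it builds decoder_input_ids from labels):

```python
def shift_right(labels, decoder_start_token_id=0):
    # Prepend the decoder start token (for T5, the pad token, id 0)
    # and drop the last label so the lengths match.
    return [decoder_start_token_id] + labels[:-1]

labels = [20239, 1670, 3664, 18, 900, 1]
decoder_input_ids = shift_right(labels)
print(decoder_input_ids)  # [0, 20239, 1670, 3664, 18, 900]
```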
In other words, we prepend the decoder inputs with a special token (the decoder start token, which for T5 is the padding token, with index 0), and the decoder then needs to predict, in parallel, that:
- the token that follows the decoder start token is “salut”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “salut”).
- the token that follows “salut” is “comment”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “comment”).
- the token that follows “comment” is “ça”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “ça”).
- etc.
- the token that follows “va” is “</s>” (meaning, the end-of-sequence or EOS token). Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “</s>”).
In the code, this is done in one go, by comparing the model’s logits - which are of shape (batch_size, seq_len, vocab_size) - to the ground-truth labels:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
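A self-contained sketch of this loss computation, using dummy random logits in place of a real model’s output (the sizes match the running example):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 1, 6, 32128

# Dummy logits standing in for the decoder's output; a real model
# would produce these from input_ids and decoder_input_ids.
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[20239, 1670, 3664, 18, 900, 1]])

# -100 is the index CrossEntropyLoss ignores, so padded label
# positions can be masked out of the loss.
loss_fct = CrossEntropyLoss(ignore_index=-100)

# Flatten to (batch_size * seq_len, vocab_size) vs (batch_size * seq_len,):
# one cross-entropy term per target position, averaged into a scalar.
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
print(loss)
```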