What is the loss function for T5?

T5 uses the regular cross-entropy loss (like any language model).

Suppose that you are fine-tuning T5 for translation, and you have the following training example:

* source sentence: "hello how are you"
* target sentence: "salut comment ça-va"

First, one needs to tokenize the sentences for the model using T5Tokenizer. Assuming that each word is tokenized into a single token (with "ça-va" split into three tokens), and adding T5’s special token (namely </s>, which indicates the end of a sequence), we provide the following inputs to the model:

* input tokens = [hello, how, are, you, </s>]
* label tokens = [salut, comment, ça, -, va, </s>]

Of course, we don’t provide these tokens as text to the model, but rather as integer IDs, which refer to row indices in an embedding matrix, so the actual inputs will look like:

* input_ids = [21820, 149, 33, 25, 1]
* labels = [20239, 1670, 3664, 18, 900, 1]
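
If you want to reproduce these IDs, here is a minimal sketch using T5Tokenizer from the Transformers library (the t5-small checkpoint name is just an example; the exact IDs depend on the tokenizer you load):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # example checkpoint

source = "hello how are you"
target = "salut comment ça-va"

# The tokenizer appends the </s> token (ID 1) automatically
input_ids = tokenizer(source).input_ids
labels = tokenizer(target).input_ids

print(input_ids)  # should print the input_ids shown above
print(labels)     # should print the labels shown above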

In that case, you first provide the input_ids to T5’s encoder, which will turn them into a tensor of shape (batch_size, seq_len, hidden_size). Next, T5’s decoder will predict, for each position of the target sequence, the correct next token. This happens as follows:

salut                 comment    ça       -      va     </s>   => label tokens

20239                 1670       3664     18     900    1      => labels

------------------------------------------------------------------------------
                                   DECODER
------------------------------------------------------------------------------

0                     20239      1670     3664   18     900    => decoder_input_ids

decoder_start_token   salut      comment  ça     -      va     => decoder input tokens
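
Note that you don’t have to create the decoder_input_ids yourself: when you pass labels to T5ForConditionalGeneration, the model builds them internally by shifting the labels one position to the right. A minimal sketch of that shift (the helper name shift_right below is hypothetical, purely to illustrate the operation):

import torch

def shift_right(labels, decoder_start_token_id=0):
    # Hypothetical helper: prepend the decoder start token and drop the last label
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    return shifted

labels = torch.tensor([[20239, 1670, 3664, 18, 900, 1]])
print(shift_right(labels))  # tensor([[    0, 20239,  1670,  3664,    18,   900]])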

To spell it out: the decoder inputs are the labels shifted one position to the right, with a special token prepended (the decoder start token, which for T5 is the padding token, index 0). The decoder then needs to predict, in parallel, that (a concrete sketch of these per-position losses follows the list):

  • the token that follows the decoder start token is “salut”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “salut”).
  • the token that follows “salut” is “comment”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “comment”).
  • the token that follows “comment” is “ça”. Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “ça”).
  • etc.
  • the token that follows “va” is “</s>” (meaning, the end-of-sequence or EOS token). Here, we compute the cross-entropy loss between the prediction of the model and the target token (which is “</s>”).
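
To make these per-position losses concrete, here is a small sketch with random logits standing in for the decoder output, just to show the shapes involved (32128 is the vocabulary size of the released T5 checkpoints):

import torch
import torch.nn.functional as F

vocab_size = 32128
logits = torch.randn(1, 6, vocab_size)   # (batch_size, seq_len, vocab_size), random stand-in for the decoder output
labels = torch.tensor([[20239, 1670, 3664, 18, 900, 1]])

# Position 0 scores the prediction for "salut", position 1 for "comment", and so on
per_position_loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
print(per_position_loss)         # one cross-entropy value per target token
print(per_position_loss.mean())  # averaging gives a single scalar loss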

In the code, this is done in one go, by comparing the logits of the model - which are of shape (batch_size, seq_len, vocab_size) - to the ground-truth labels (label positions set to -100, e.g. padding tokens, are ignored):

loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
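
Putting everything together, here is a minimal end-to-end sketch (again assuming the t5-small checkpoint as an example). Passing both input_ids and labels makes the model shift the labels right internally, run the encoder and decoder, and return the cross-entropy loss averaged over the target positions:

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")  # example checkpoint

input_ids = torch.tensor([[21820, 149, 33, 25, 1]])
labels = torch.tensor([[20239, 1670, 3664, 18, 900, 1]])

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.encoder_last_hidden_state.shape)  # (batch_size, seq_len, hidden_size), e.g. torch.Size([1, 5, 512]) for t5-small
print(outputs.logits.shape)                     # (batch_size, seq_len, vocab_size), e.g. torch.Size([1, 6, 32128])
print(outputs.loss)                             # scalar cross-entropy loss over the 6 target positions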
