Does modeling_bert use next-token prediction?

In modeling_bert.py, the class BertLMHeadModel(BertPreTrainedModel) has the following code for its loss:

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

Why does BERT do next token prediction here with label shift?

Hi,

Yes, for any causal language model that you can train in the Transformers library, the model internally shifts the labels by one position so that it learns to predict the next token. The convenience of this is that users can simply copy the labels from the inputs, i.e. labels = input_ids.clone(), although users then typically also replace tokens that the model shouldn’t learn to predict (like padding tokens) with -100.

Visually (taken from my explanation here):

As can be seen, the labels (top row) are equal to the inputs (bottom row), just shifted one position to the left, and with tokens which the model shouldn’t learn to predict (like the special <|begin_of_text|> in the figure above) replaced by -100.
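
To make the shift concrete, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without PyTorch). The token ids, the padding id of 0, and the helper names make_labels and shifted_pairs are made up for illustration and are not part of the library; only the -100 convention and the two slices come from the code quoted above.

```python
IGNORE_INDEX = -100  # label value that PyTorch's CrossEntropyLoss skips by default


def make_labels(input_ids, pad_token_id):
    """Copy the inputs and mask padding so it does not contribute to the loss."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


def shifted_pairs(input_ids, labels):
    """Pair the token seen at position t with the label at position t + 1."""
    context = input_ids[:-1]  # mirrors prediction_scores[:, :-1, :]
    targets = labels[1:]      # mirrors labels[:, 1:]
    # Drop pairs whose target is masked, as the loss would ignore them.
    return [(c, t) for c, t in zip(context, targets) if t != IGNORE_INDEX]


# Hypothetical token ids; 0 stands in for the padding token.
input_ids = [101, 7592, 2088, 102, 0, 0]
labels = make_labels(input_ids, pad_token_id=0)
pairs = shifted_pairs(input_ids, labels)

print(labels)  # [101, 7592, 2088, 102, -100, -100]
print(pairs)   # [(101, 7592), (7592, 2088), (2088, 102)]
```

Each pair reads as "having seen this token (and everything before it), predict that one", which is why the user can pass labels that are an unshifted copy of the inputs.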

Thanks a lot @nielsr !

Yes, I understand next-token prediction and the label shift. But BERT is not a causal language model (CLM), so I am confused about why it has a label shift here. Given that it's a masked language model (MLM), I assumed it should just compute cross-entropy over the masked tokens, with no need for a shift?

That’s because some people were interested in initializing decoder-only LLMs with the weights of BERT. This was mainly for the EncoderDecoderModel class, where the weights of both the encoder and the decoder were initialized from a pre-trained BERT. See Leveraging Pre-trained Language Model Checkpoints for Encoder-Decoder Models.
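
The difference between the two objectives discussed in this thread can be sketched as follows. This is an illustrative plain-Python version, not the library's implementation: the helpers cross_entropy, mlm_loss and clm_loss are hypothetical names, and the logits are toy values over a three-token vocabulary. The unshifted variant corresponds to what the question expected for masked language modeling (as far as I know, this is what the separate BertForMaskedLM class does), while the shifted variant mirrors the BertLMHeadModel code quoted in the question.

```python
import math

IGNORE_INDEX = -100  # positions with this label are skipped


def cross_entropy(logits, targets):
    """Mean negative log-likelihood over positions whose target is not ignored.

    Assumes at least one target is not IGNORE_INDEX.
    """
    losses = []
    for row, target in zip(logits, targets):
        if target == IGNORE_INDEX:
            continue
        log_norm = math.log(sum(math.exp(x) for x in row))
        losses.append(log_norm - row[target])
    return sum(losses) / len(losses)


def mlm_loss(logits, labels):
    """Masked-LM style: position t is scored against labels[t], no shift."""
    return cross_entropy(logits, labels)


def clm_loss(logits, labels):
    """Causal-LM style: position t is scored against labels[t + 1]."""
    return cross_entropy(logits[:-1], labels[1:])


# Toy logits for a 3-token sequence: each position strongly predicts
# the token that comes NEXT in the sequence [0, 1, 2].
logits = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]
labels = [0, 1, 2]

print(clm_loss(logits, labels))  # small: predictions line up after the shift
print(mlm_loss(logits, labels))  # large: predictions are off by one position
```

With MLM, only the masked positions would carry a real label and every other position would be set to -100, so no shift is needed; with the shifted loss, a model used as a decoder is trained to predict token t + 1 from position t.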