### Feature request
I believe `labels` in the training of causal LMs means the value to predict at time step `n`, i.e., the next token. In other words, I'd assume that if `labels` is given, it is already shifted by one in the data loader w.r.t. `input_ids`.
However, in `LlamaForCausalLM.forward()`, I found that the labels are always shifted, silently.
https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/llama/modeling_llama.py#L1205-L1210
```python
Args:
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
```
...
```python
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
```
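To make the implication concrete, here is a minimal sketch of how `labels` is apparently meant to be passed given the code above, i.e. unshifted and identical to `input_ids` (I use `gpt2` only as a small, public stand-in; the Llama classes shift internally in the same way):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is used here only because it is small and ungated; the same
# applies to LlamaForCausalLM and the other *ForCausalLM classes.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer("the quick brown fox jumps over the lazy dog", return_tensors="pt")

# Because forward() shifts the labels internally, they should NOT be
# pre-shifted: passing labels identical to input_ids already yields
# the standard next-token prediction loss.
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["input_ids"])
print(outputs.loss)
```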
I found this quite unexpected, hence calling it "silent". Since this is a causal LM, shouldn't it not shift the labels by default? In the GPT2 modeling code, this behavior is at least documented explicitly.
https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/gpt2/modeling_gpt2.py#L1309-L1314
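For reference, the GPT2 docstring at the linked lines reads roughly as follows (paraphrased from memory, the exact wording may differ slightly from the linked revision); the sentence in bold is what the Llama docstring is missing:
```python
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
    Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
    `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
    are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
```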
Gemma2 has the same behavior and no explicit mention of it in the docstring.
https://github.com/huggingface/transformers/blob/f1d822ba337499d429f832855622b97d90ac1406/src/transformers/models/gemma2/modeling_gemma2.py#L978-L982
I think we should at least make the docstring mention this, if changing the behavior is too risky at this point.
### Motivation
I didn't expect this behavior and used my own data loader, which already does the shifting, since I believed that is what `labels` should mean. As a result, I ended up fine-tuning a model to predict the next-next token, which produced gibberish.
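For illustration only (toy token ids, no real tokenizer), this is the double shift that results when the data loader shifts the labels and the model then shifts them again:
```python
import torch

# Toy token ids (illustration only).
input_ids = torch.tensor([[10, 11, 12, 13, 14]])

# What my data loader produced: labels already shifted by one.
labels = torch.tensor([[11, 12, 13, 14, -100]])

# What the model then does internally (the second shift):
shift_labels = labels[..., 1:]          # [[12, 13, 14, -100]]
context_tokens = input_ids[..., :-1]    # [[10, 11, 12, 13]]

# Pair each context position with its effective training target:
for ctx, tgt in zip(context_tokens[0].tolist(), shift_labels[0].tolist()):
    print(f"after token {ctx} the model is trained to predict {tgt}")
# after token 10 the model is trained to predict 12  <- next-next token
```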
### Your contribution
- Hopefully this issue helps communicate the behavior to other users.
- I can make a one-line change to the docstring.
- Not sure exactly how, but it would be great if this potential misunderstanding could be checked. Technically, we can detect whether the labels are already shifted, though I don't know where the best place for such a check would be; a rough sketch follows below.
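A rough sketch of what such a check could look like (the function name, threshold, and warning wording are all hypothetical, not an existing transformers utility, and the heuristic can misfire on highly repetitive sequences):
```python
import warnings

import torch


def warn_if_labels_look_pre_shifted(input_ids: torch.Tensor, labels: torch.Tensor) -> None:
    """Heuristic: if labels[..., :-1] largely matches input_ids[..., 1:],
    the labels were probably already shifted by the data loader."""
    if input_ids.shape != labels.shape:
        return
    valid = labels[..., :-1] != -100  # ignore padding / masked positions
    if valid.sum() == 0:
        return
    match = (labels[..., :-1] == input_ids[..., 1:]) & valid
    if match.sum().float() / valid.sum().float() > 0.95:  # arbitrary threshold
        warnings.warn(
            "`labels` look like they were already shifted relative to `input_ids`. "
            "Causal LM heads shift the labels internally, so pass unshifted labels "
            "(typically `labels = input_ids`)."
        )
```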