Solution with explanation
So, I have realized that this problem only occurs when using prompt tuning with SFTTrainer and CausalLM models. Prompt tuning prepends trainable embeddings to the input embeddings, and because the forward pass is auto-regressive, the prepended soft prompt of length 100 also shows up in the model outputs: the returned logits cover the virtual-token positions in addition to the real input tokens.
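To illustrate the shape mismatch, here is a minimal sketch; the base model ("gpt2") and the input text are placeholder assumptions, not the actual setup:

```python
import torch
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model, just for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_name),
    PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=100),
)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# The soft prompt is prepended, so the logits cover num_virtual_tokens extra positions.
print(inputs["input_ids"].shape[1])  # e.g. 2
print(outputs.logits.shape[1])       # e.g. 102 == 2 + num_virtual_tokens
```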
I am not sure whether this is a problem with the PEFT library's implementation of prompt tuning for CausalLM models, or whether it is the intended behavior and needs to be handled on the TRL SFTTrainer side. I managed to create a quick workaround by slicing off the first num_virtual_tokens of the outputs in the compute_loss method when prompt tuning is used:
```python
from peft import PeftModel, PeftType
from trl import SFTTrainer


class PromptTuningSFTTrainer(SFTTrainer):  # example subclass carrying the workaround
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute training loss and additionally compute token accuracies
        """
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )
        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
        if "labels" in inputs and not self.args.use_liger:
            if isinstance(model, PeftModel) and model.peft_type == PeftType.PROMPT_TUNING:
                # Drop the extra num_virtual_tokens logit positions (in addition to the
                # usual last one) so the shape matches the shifted labels.
                num_virtual_tokens = model.peft_config["default"].num_virtual_tokens
                shift_logits = outputs.logits[..., : -(1 + num_virtual_tokens), :].contiguous()
            else:
                shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()
            # ... the rest of the token-accuracy computation is unchanged from SFTTrainer ...
        return (loss, outputs) if return_outputs else loss
```
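For completeness, a minimal sketch of how the subclass above could be wired into a run; the base model ("gpt2"), the tiny toy dataset, and the training arguments are illustrative assumptions rather than the actual setup:

```python
from datasets import Dataset
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder base model
tokenizer.pad_token = tokenizer.eos_token

peft_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"),
    PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=100),
)

trainer = PromptTuningSFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="sft-prompt-tuning", max_steps=1, report_to="none"),
    train_dataset=Dataset.from_dict({"text": ["Hello world!", "Soft prompts are trainable."]}),
    processing_class=tokenizer,
)
trainer.train()
```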
For some reason, the token accuracy is still really low compared to using LoRA. I may have to investigate further, and I will probably open a PR to fix this.