TRL SFTTrainer 0.15 compute_token_accuracy error

I have updated TRL from 0.11 to 0.15. When training Llama-3.1-8B-Instruct, I now get the following error:

Traceback (most recent call last):
  File "/home/jovyan/prompt-arithmetics/llama31_instruct_pt.py", line 328, in <module>
    trainer.train()
  File "/home/jovyan/my-conda-envs/tpv/lib/python3.12/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/my-conda-envs/tpv/lib/python3.12/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/my-conda-envs/tpv/lib/python3.12/site-packages/transformers/trainer.py", line 3698, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/my-conda-envs/tpv/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 453, in compute_loss
    accuracy = compute_token_accuracy(shift_logits, shift_labels)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/my-conda-envs/tpv/lib/python3.12/site-packages/trl/trainer/utils.py", line 1664, in compute_token_accuracy
    correct_predictions = (predictions == labels) & mask
                           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (355) must match the size of tensor b (255) at non-singleton dimension 1

I have traced this to the fact that, in version 0.15, SFTTrainer overrides the compute_loss method of the transformers Trainer class, but I have no idea why the error happens. The problem is probably that the label size differs from the size of the model outputs. I have set max_seq_length to 512 in SFTConfig.

Here is how I initialize the tokenizer and model (nothing special really):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model

# model_args, data_args and peft_config come from my argument parsing / PEFT setup
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    torch_dtype=torch.bfloat16,
).to("cuda")
model.active_adapters = [
    "default"
]  # fix because llama has some active adapters for some reason
model = get_peft_model(model, peft_config=peft_config)

tokenizer = AutoTokenizer.from_pretrained(
    data_args.data_tokenizer_name_or_path,
    trust_remote_code=True,
    padding_side="right",
)
# Llama ships without a pad token, so use one of the reserved special tokens
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

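The trainer is a plain SFTConfig/SFTTrainer setup. Simplified, it looks roughly like this (the output path and dataset are placeholders; the only setting relevant here is max_seq_length=512):

from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="llama31_pt",      # placeholder path
    max_seq_length=512,
    bf16=True,
)
trainer = SFTTrainer(
    model=model,                  # the PEFT-wrapped model from above
    args=training_args,
    train_dataset=train_dataset,  # placeholder dataset
    processing_class=tokenizer,
)
trainer.train()
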
Does anyone have an idea what could be causing the error?

Thank you!


Solution with explanation

So, I have realized that this problem occurs only when using prompt tuning with the SFTTrainer and CausalLM models. Prompt tuning prepends trainable embeddings to the input embeddings, so the forward pass also produces logits for the prepended soft prompt. With my soft prompt of length 100, the logits are 100 positions longer than the labels, which is exactly the 355 vs. 255 mismatch in the traceback above.

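A quick way to see the extra positions (gpt2 here is just a small stand-in for illustration, not my actual setup):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PromptTuningConfig, TaskType, get_peft_model

# Wrap a small causal LM with a 100-token soft prompt, the same way prompt tuning does
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"),
    PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=100),
)

inputs = tokenizer("Hello there", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

print(inputs["input_ids"].shape[1])  # 2 real tokens
print(logits.shape[1])               # 102 = 2 + 100: the soft-prompt positions show up in the logits
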
I am not sure whether this is a problem in the PEFT implementation of prompt tuning for CausalLM models, or whether it is intended behavior that needs to be handled on the TRL SFTTrainer side. I managed to create a quick workaround in the compute_loss method: when prompt tuning is used, the extra num_virtual_tokens positions are sliced off the model outputs so that they match the label length:

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    Compute training loss and additionally compute token accuracies
    """
    (loss, outputs) = super().compute_loss(
        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
    )

    # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
    if "labels" in inputs and not self.args.use_liger:
        if isinstance(model, PeftModel) and model.peft_type == PeftType.PROMPT_TUNING:
            # Prompt tuning: the logits carry num_virtual_tokens extra positions,
            # so slice them off as well to match the label length
            num_virtual_tokens = model.peft_config["default"].num_virtual_tokens
            shift_logits = outputs.logits[..., :-(1 + num_virtual_tokens), :].contiguous()
        else:
            shift_logits = outputs.logits[..., :-1, :].contiguous()

        shift_labels = inputs["labels"][..., 1:].contiguous()

        # Compute the accuracy on the now equally sized tensors and store the
        # metric the same way the stock SFTTrainer does
        accuracy = compute_token_accuracy(shift_logits, shift_labels)
        self._metrics["mean_token_accuracy"].append(accuracy)

    return (loss, outputs) if return_outputs else loss

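For context, the override lives in a small SFTTrainer subclass that I use in place of the stock trainer (the class name is just illustrative):

from peft import PeftModel, PeftType
from trl import SFTTrainer
from trl.trainer.utils import compute_token_accuracy


class PromptTuningSFTTrainer(SFTTrainer):
    # paste the compute_loss override from above here
    ...
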
For some reason, the token accuracy is still really low (compared to using LoRA). I may have to investigate even further, and I will probably open a PR to fix this.

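I suspect the remaining issue may be the alignment itself: the soft prompt is prepended, so its logits sit at the front of outputs.logits, while my slice removes positions from the end, which would leave the predictions shifted by num_virtual_tokens relative to the labels. An untested sketch that drops the soft-prompt positions from the front instead (same variables as in the override above):

# Untested alternative: remove the prepended soft-prompt logits from the front,
# then apply the usual one-position shift, so predictions and labels stay aligned
num_virtual_tokens = model.peft_config["default"].num_virtual_tokens
shift_logits = outputs.logits[..., num_virtual_tokens:-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()
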
