How do I fix this error when training with TRL, QLoRA, and PPO?

When running a PPO step with TRL on a QLoRA model loaded through unsloth, training crashes with a RuntimeError about an in-place view modification (full traceback at the end of this post). This is my setup:

from trl import PPOConfig, PPOTrainer
import torch

config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=16,
    mini_batch_size=1,
    gradient_accumulation_steps=16,
    ppo_epochs=1,
)

# model, ref_model, tokenizer, and dataset are defined earlier
# (4-bit QLoRA model with a value head); score_output_list is my reward function
reward_model = score_output_list

ppo_trainer = PPOTrainer(
    model=model,
    ref_model=ref_model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
    optimizer="adamw_8bit",
)

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "max_new_tokens": 50,
    "pad_token_id": tokenizer.eos_token_id,
}
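
The cell that actually crashes is the usual TRL generate / score / step loop. A minimal sketch of it (batching details simplified; score_output_list is my own reward function returning one score per response):

for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]

    # Generate a response for every query in the batch
    response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False, **generation_kwargs
    )
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Score the responses; rewards must be a list of scalar tensors
    rewards = [torch.tensor(s) for s in score_output_list(batch["response"])]
    print(rewards)

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    #### Save model

Full traceback: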



RuntimeError                              Traceback (most recent call last)
Cell In[20], line 22
     20 print(rewards)
     21 #### Run PPO step
---> 22 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
     23 ppo_trainer.log_stats(stats, batch, rewards)
     25 #### Save model

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:798, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    795 with self.accelerator.accumulate(self.model):
    796     model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
---> 798     logprobs, logits, vpreds, _ = self.batched_forward_pass(
    799         self.model,
    800         mini_batch_dict["queries"],
    801         mini_batch_dict["responses"],
    802         model_inputs,
    803         return_logits=True,
    804     )
    805     train_stats = self.train_minibatch(
    806         mini_batch_dict["logprobs"],
    807         mini_batch_dict["values"],
    (...)
    813         mini_batch_dict["returns"],
    814     )
    815     all_stats.append(train_stats)

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:994, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    992 if response_masks is not None:
    993     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
---> 994 logits, _, values = model(**input_kwargs)
    996 if self.is_encoder_decoder:
    997     input_ids = input_kwargs["decoder_input_ids"]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
---> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559 args = bw_hook.setup_input_hook(args)
---> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run

File /opt/conda/lib/python3.10/site-packages/trl/models/modeling_value_head.py:171, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    168 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    169     kwargs.pop("past_key_values")
---> 171 base_model_output = self.pretrained_model(
    172     input_ids=input_ids,
    173     attention_mask=attention_mask,
    174     **kwargs,
    175 )
    177 last_hidden_state = base_model_output.hidden_states[-1]
    178 lm_logits = base_model_output.logits

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
---> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
---> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:882, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
    869 def PeftModelForCausalLM_fast_forward(
    870     self,
    871     input_ids=None,
    (...)
    880     **kwargs,
    881 ):
---> 882     return self.base_model(
    883         input_ids=input_ids,
    884         causal_mask=causal_mask,
    885         attention_mask=attention_mask,
    886         inputs_embeds=inputs_embeds,
    887         labels=labels,
    888         output_attentions=output_attentions,
    889         output_hidden_states=output_hidden_states,
    890         return_dict=return_dict,
    891         **kwargs,
    892     )

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
---> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
---> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
---> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164     output = module._old_forward(*args, **kwargs)
    165 else:
---> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/mistral.py:212, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    204     outputs = LlamaModel_fast_forward_inference(
    205         self,
    206         input_ids,
    (...)
    209         attention_mask = attention_mask,
    210     )
    211 else:
---> 212     outputs = self.model(
    213         input_ids=input_ids,
    214         causal_mask=causal_mask,
    215         attention_mask=attention_mask,
    216         position_ids=position_ids,
    217         past_key_values=past_key_values,
    218         inputs_embeds=inputs_embeds,
    219         use_cache=use_cache,
    220         output_attentions=output_attentions,
    221         output_hidden_states=output_hidden_states,
    222         return_dict=return_dict,
    223     )
    224 pass
    226 hidden_states = outputs[0]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
---> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
---> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164     output = module._old_forward(*args, **kwargs)
    165 else:
---> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:680, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    677     hidden_states = layer_outputs[0]
    679 else:
---> 680     layer_outputs = decoder_layer(
    681         hidden_states,
    682         causal_mask=causal_mask,
    683         attention_mask=attention_mask,
    684         position_ids=position_ids,
    685         past_key_value=past_key_value,
    686         output_attentions=output_attentions,
    687         use_cache=use_cache,
    688         padding_mask=padding_mask,
    689     )
    690     hidden_states = layer_outputs[0]
    691 pass

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
---> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
---> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164     output = module._old_forward(*args, **kwargs)
    165 else:
---> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:423, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    412 hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
    413 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    414     hidden_states=hidden_states,
    415     causal_mask=causal_mask,
    (...)
    421     padding_mask=padding_mask,
    422 )
---> 423 hidden_states += residual
    425 # Fully Connected
    426 residual = hidden_states

RuntimeError: Output 0 of LoRA_WBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
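
The last frame is unsloth's LlamaDecoderLayer_fast_forward, where `hidden_states += residual` modifies in place a tensor that autograd says is a view produced by LoRA_WBackward, and the error message suggests cloning the output of the custom Function. The only workaround I can think of is locally patching that line to an out-of-place add (my own guess, not an official fix):

# Hypothetical local patch to unsloth/models/llama.py around line 423,
# replacing the in-place add flagged in the traceback:
#     hidden_states += residual
# with an out-of-place add, which allocates a new tensor instead of
# mutating the view returned by the custom Function:
hidden_states = hidden_states + residual

Is something like this (or cloning the attention output before the add) the right direction, or is there a supported way to run TRL's PPOTrainer on an unsloth QLoRA model?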