I’m trying to fine-tune Whisper with LoRA. When I enable gradient_checkpointing, I get the following error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Are LoRA and gradient_checkpointing inherently incompatible? Or is this a bug?
This is the code I have:
from transformers import WhisperForConditionalGeneration
# Load the pretrained Whisper checkpoint and clear the generation constraints
# stored on its config.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []


def make_inputs_require_grad(module, inputs, output):
    """Forward hook that marks *output* as requiring grad.

    Args:
        module: The module the hook is registered on (unused).
        inputs: Positional inputs passed to the module (unused).
        output: The tensor produced by the module; modified in place.
    """
    output.requires_grad_(True)


# NOTE(review): assumes the base weights are frozen by a LoRA/PEFT wrapper
# applied elsewhere (not shown here) -- confirm. With frozen weights and
# gradient_checkpointing=True, nothing entering the checkpointed layers
# requires grad, so backward() fails with "element 0 of tensors does not
# require grad and does not have a grad_fn". Making the first encoder
# convolution's output require grad keeps the autograd graph connected.
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
from transformers import Seq2SeqTrainingArguments
# Training configuration for the Seq2SeqTrainer run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./temp",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=500,
    max_steps=40000,
    evaluation_strategy="steps",
    gradient_checkpointing=True,
    fp16=True,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=10000,
    eval_steps=1000,
    logging_steps=25,
    # required as the PeftModel forward doesn't have the signature of the
    # wrapped model's forward
    remove_unused_columns=False,
    label_names=["labels"],  # same reason as above
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower "wer" is treated as better
    push_to_hub=False,
    auto_find_batch_size=True,
)
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that stores the model under ``adapter_model`` on save.

    On every save event the model passed in ``kwargs["model"]`` is written via
    ``save_pretrained`` into an ``adapter_model`` sub-folder of the current
    checkpoint directory, and a ``pytorch_model.bin`` file in that checkpoint
    directory is deleted if one exists.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the model into the checkpoint folder for the current step.

        Args:
            args: Training arguments; ``output_dir`` is the checkpoint root.
            state: Trainer state; ``global_step`` names the checkpoint folder.
            control: Trainer control object, returned unchanged.
            **kwargs: Must contain ``"model"``, the object to save.

        Returns:
            The ``control`` object that was passed in.
        """
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        ckpt_dir = os.path.join(args.output_dir, step_dir_name)

        adapter_dir = os.path.join(ckpt_dir, "adapter_model")
        kwargs["model"].save_pretrained(adapter_dir)

        # Drop the pytorch_model.bin file from the checkpoint, if present.
        full_weights_file = os.path.join(ckpt_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)

        return control
# Build the trainer from the model, datasets, collator and checkpoint callback
# defined above, then start training.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["validation"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
    # NOTE(review): training_args sets metric_for_best_model="wer" together
    # with load_best_model_at_end=True, but no compute_metrics is passed, so
    # presumably no "wer" value is produced at evaluation -- confirm.
    # compute_metrics=compute_metrics,
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

trainer.train()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[27], line 1
----> 1 trainer.train()
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1537 hf_hub_utils.enable_progress_bars()
1538 else:
-> 1539 return inner_training_loop(
1540 args=args,
1541 resume_from_checkpoint=resume_from_checkpoint,
1542 trial=trial,
1543 ignore_keys_for_eval=ignore_keys_for_eval,
1544 )
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/accelerate/utils/memory.py:136, in find_executable_batch_size.<locals>.decorator(*args, **kwargs)
134 raise RuntimeError("No executable batch size found, reached zero.")
135 try:
--> 136 return function(batch_size, *args, **kwargs)
137 except Exception as e:
138 if should_reduce_batch_size(e):
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:1821, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1818 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
1820 with self.accelerator.accumulate(model):
-> 1821 tr_loss_step = self.training_step(model, inputs)
1823 if (
1824 args.logging_nan_inf_filter
1825 and not is_torch_tpu_available()
1826 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
1827 ):
1828 # if loss is nan or inf simply add the average of previous logged losses
1829 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/transformers/trainer.py:2677, in Trainer.training_step(self, model, inputs)
2675 scaled_loss.backward()
2676 else:
-> 2677 self.accelerator.backward(loss)
2679 return loss.detach() / self.args.gradient_accumulation_steps
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/accelerate/accelerator.py:1851, in Accelerator.backward(self, loss, **kwargs)
1849 return
1850 elif self.scaler is not None:
-> 1851 self.scaler.scale(loss).backward(**kwargs)
1852 else:
1853 loss.backward(**kwargs)
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
477 if has_torch_function_unary(self):
478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
(...)
485 inputs=inputs,
486 )
--> 487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File ~/.conda/envs/lora_3.10/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn