I was trying to run SFT on the tatsu-lab/alpaca_farm dataset using EleutherAI/pythia-14m (cuz I'm GPU poor). In particular, I want to compute the loss on the responses only, not the instructions, so I used the DataCollatorForCompletionOnlyLM class as shown below. This setup outputs NaN loss within the very first few batches. However, if I don't use the data collator and instead compute the loss on the entire sequence, I never run into any issues. Weird!
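For reference, here's a quick way to see what the collator does to the labels (a minimal sketch; it assumes tokenizer and collator are set up as in the code further down):

text = "<|prompter|> Some instruction\n<|assistant|> Some response"
ids = tokenizer(text)["input_ids"]
batch = collator([ids])

# Everything up to and including the response template should be masked
# with -100, so only the response tokens contribute to the loss.
print(batch["labels"][0])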
I have tried a few architectural choices, including applying LoRA to different modules, different lora_alpha values, etc. The NaN loss always traces back to NaN gradients on the MLP weights of the last few attention blocks; interestingly, no other modules ever get NaN gradients. I also noticed that gradient magnitude generally increases from lower to higher layers, which is the opposite of what you'd expect from exploding gradients (those grow as they propagate back toward the lower layers), so gradient explosion seems unlikely. I visualized the attention maps at the last few layers but couldn't spot any issues.
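Here's roughly the check I used to localize the NaN gradients (a minimal sketch; call it right after loss.backward(), e.g. from a custom training step or callback):

import torch

def find_nan_grads(model):
    # Report every parameter whose gradient contains a NaN, plus the
    # overall max |grad| as a rough gradient-explosion check.
    max_grad = 0.0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if torch.isnan(param.grad).any():
            print(f"NaN gradient in {name}")
        else:
            max_grad = max(max_grad, param.grad.abs().max().item())
    print(f"max |grad| among non-NaN params: {max_grad:.3e}")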
There are a few other things I have checked, such as properly truncating the dataset and setting the sequence length so that the "<|assistant|>" token appears in every input for the data collator to segment on (if it's missing, the data collator throws a warning, but that shouldn't affect anything anyway because the loss is properly masked).
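For completeness, this is roughly the truncation check (a sketch; it assumes train_dataset is the raw alpaca_farm split and args.max_seq_length matches what's passed to the trainer below):

response_ids = tokenizer.encode(response_template, add_special_tokens=False)

def has_response_template(text):
    ids = tokenizer(text, truncation=True, max_length=args.max_seq_length)["input_ids"]
    # Look for the response template's token ids as a contiguous subsequence.
    return any(ids[i:i + len(response_ids)] == response_ids
               for i in range(len(ids) - len(response_ids) + 1))

assert all(has_response_template(t) for t in formatting_prompts_func(train_dataset[:256]))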
So I'm not sure why this is happening. I'd really appreciate it if anyone has ideas!
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

instruction_template = "<|prompter|>"
response_template = "<|assistant|>"
collator = DataCollatorForCompletionOnlyLM(
    instruction_template=instruction_template,
    response_template=response_template,
    tokenizer=tokenizer,
)
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        # 'input' is a batched column (a list of strings), so index it with i.
        if example['input'][i] != "":
            text = f"{instruction_template} {example['instruction'][i]} {example['input'][i]}\n{response_template} {example['output'][i]}"
        else:
            text = f"{instruction_template} {example['instruction'][i]}\n{response_template} {example['output'][i]}"
        output_texts.append(text)
    return output_texts
trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
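In case it helps, one more sanity check on the masking claim above (a sketch): if a collated batch ever ended up with every label masked, the mean cross-entropy would be 0/0 = NaN, so I'd confirm some unmasked labels survive in each batch:

texts = formatting_prompts_func(train_dataset[:32])
features = [tokenizer(t, truncation=True, max_length=args.max_seq_length)["input_ids"]
            for t in texts]
for i in range(0, len(features), 8):
    batch = collator(features[i:i + 8])
    # Count the label tokens that actually contribute to the loss.
    n_live = (batch["labels"] != -100).sum().item()
    print(f"batch {i // 8}: {n_live} unmasked label tokens")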