Hey, I’m trying to fine-tune Llama 2, but I can’t see where the checkpoints are being saved. I’m using the following code:
base_model_name = "meta-llama/Llama-2-7b-hf"

# Quantize the base weights to 4-bit NF4; computation runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# "auto" lets transformers/accelerate place the layers on the available devices.
device_map = "auto"

# NOTE(review): Llama 2 is natively supported by transformers, so
# trust_remote_code should not be needed here; consider dropping it, since it
# permits executing code shipped with the model repo. Confirm before removing.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map=device_map,
    quantization_config=bnb_config,
    use_auth_token=True,
    trust_remote_code=True,
)

# Training-time config tweaks: turn off use_cache, and force pretraining_tp to 1
# so the model uses its plain linear layers rather than the tensor-parallel
# slicing path. More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1
# LoRA adapter settings: rank-64 update matrices scaled by lora_alpha=16,
# 10% dropout on the adapter path, no bias terms trained, causal-LM task.
# NOTE(review): no target_modules given, so PEFT presumably falls back to its
# per-architecture defaults for Llama — confirm which layers get adapters.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
# Reuse the EOS token for padding (presumably because the Llama 2 tokenizer
# ships without a dedicated pad token).
tokenizer.pad_token = tokenizer.eos_token
# Checkpoints are written under this directory as checkpoint-<global_step>/.
output_dir = "./Llama-2-7b-hf-qlora"

# Shared interval, in optimizer steps, for both evaluation and checkpointing.
# Change to 10 to evaluate and save every 10 steps instead of every 5.
checkpoint_steps = 5

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=5,
    max_steps=400,
    logging_dir="./logs",  # Directory for storing logs
    evaluation_strategy="steps",  # Evaluate every `eval_steps` steps
    eval_steps=checkpoint_steps,
    save_strategy="steps",  # Save a checkpoint every `save_steps` steps
    # save_steps is independent of eval_steps and defaults to 500. With
    # max_steps=400, leaving it unset meant no checkpoint was ever written.
    save_steps=checkpoint_steps,
    # NOTE(review): at an interval of 5 this keeps 400 / 5 = 80 checkpoints;
    # consider save_total_limit if disk space is a concern.
    # do_eval is not read by Trainer itself; a non-"no" evaluation_strategy
    # already enables evaluation. Kept for compatibility with scripts that check it.
    do_eval=True,
)
class PeftSavingCallback(TrainerCallback):
    """Save the model into each checkpoint directory and drop full-weight files.

    Intended for PEFT training: once the adapter has been written, the
    full-model weight files in the checkpoint are redundant and are deleted
    to save disk space.
    """

    # Full-model weight filenames to remove from a checkpoint. Newer
    # transformers versions default to safetensors (model.safetensors), which
    # a .bin-only check silently missed.
    _FULL_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")

    def on_save(self, args, state, control, **kwargs):
        """Write the model into ``checkpoint-<global_step>`` and prune full weights."""
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        # Only prune when an adapter config is present. For a non-PEFT model,
        # the file deleted below would be the only copy of the weights just saved.
        if not os.path.isfile(os.path.join(checkpoint_path, "adapter_config.json")):
            return

        for filename in self._FULL_WEIGHT_FILES:
            weights_path = os.path.join(checkpoint_path, filename)
            if os.path.isfile(weights_path):
                os.remove(weights_path)
# Extra Trainer callbacks: write the PEFT adapter into each saved checkpoint.
callbacks = [PeftSavingCallback()]

# Maximum sequence length handed to SFTTrainer.
max_seq_length = 512

# train_dataset / eval_dataset are defined elsewhere; samples are read from
# their "text" field.
trainer = SFTTrainer(
    model=base_model,
    args=training_args,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    callbacks=callbacks,
)
trainer.train()
(I added the callback based on the Supervised Fine-tuning Trainer guide.)
How do I get a checkpoint saved every 5 or 10 steps?