SFTTrainer checkpointing

Hey, I’m trying to fine-tune Llama 2 and I can’t see where the checkpoints are getting saved. I am using the following code:

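import os

import torch
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainingArguments,
)
from trl import SFTTrainer

base_model_name = "meta-llama/Llama-2-7b-hf"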

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

device_map = "auto"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    trust_remote_code=True,
    use_auth_token=True
)
base_model.config.use_cache = False

# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1 

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

output_dir = "./Llama-2-7b-hf-qlora"

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=5,
    max_steps=400,
    evaluation_strategy="steps", # Evaluate every `eval_steps` steps
    logging_dir="./logs",        # Directory for storing logs
    save_strategy="steps",       # Save a checkpoint every `save_steps` steps
    eval_steps=5,                # Run evaluation every 5 steps
    do_eval=True                 # Enable evaluation on eval_dataset
)


class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


callbacks = [PeftSavingCallback()]
max_seq_length = 512
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Add this line
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_args,
    callbacks=callbacks
)


trainer.train()

(I added the callback based on the "Supervised Fine-tuning Trainer" guide in the TRL docs.)
How do I get a checkpoint saved every 5 or 10 steps?

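Worked this out… Fairly simple in the end: adding save_steps to TrainingArguments does the trick! save_steps defaults to 500, which is more than my max_steps=400, so no checkpoint was being written during the run.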
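
To make the interaction between max_steps and save_steps concrete, here is a rough sketch of step-based saving: a checkpoint is due whenever the global step is a multiple of save_steps. checkpoint_steps is a hypothetical helper written for this post, not part of transformers, and it ignores any extra save that some Trainer versions may make at the very end of training.

```python
def checkpoint_steps(max_steps: int, save_steps: int) -> list[int]:
    """Steps at which a step-based schedule would write a checkpoint.

    Hypothetical helper for illustration only; it models the simple rule
    "save when global_step is a multiple of save_steps".
    """
    if save_steps <= 0:
        raise ValueError("save_steps must be a positive integer")
    return list(range(save_steps, max_steps + 1, save_steps))


# With save_steps=500 and max_steps=400, no step qualifies.
print(checkpoint_steps(max_steps=400, save_steps=500))
# With save_steps=100, checkpoints are due at steps 100, 200, 300 and 400.
print(checkpoint_steps(max_steps=400, save_steps=100))
```

Under this rule, any save_steps value no larger than max_steps gives at least one checkpoint-<step> folder in output_dir.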

You have already answered the question, but here is a snippet for the sake of completeness.


# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,   # save a checkpoint every `save_steps` steps
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard"
)


trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

Are you able to resume training from a checkpoint, or to load the model from a checkpoint? I could not do either. I used SFTTrainer. Once the model was trained, I could not find a config.json file either.

Do this: trainer.train(resume_from_checkpoint=True)
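
resume_from_checkpoint also accepts a path to a specific checkpoint folder. Below is a minimal sketch for picking the newest checkpoint-<step> folder under output_dir. latest_checkpoint is a hypothetical helper written for this post; transformers ships a similar get_last_checkpoint in transformers.trainer_utils.

```python
import os
import re
from typing import Optional

# Matches folder names such as "checkpoint-100" and captures the step number.
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subfolder with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        # Compare step numbers as integers so "checkpoint-100" beats "checkpoint-20".
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

The result can then be passed along as trainer.train(resume_from_checkpoint=latest_checkpoint(output_dir)). If the helper returns None, that call simply starts a fresh run.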

Thank you. The actual problem was that I was using an older version of transformers. The problem is fixed in the second-to-last release.

Hey, I am trying to run a Mistral model from a checkpoint, but the saved checkpoints are missing some files; for example, pytorch_model.bin.index.json is not being saved. Any idea why?
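
Which files show up in a checkpoint depends on what was saved. A PEFT/LoRA run typically writes only the adapter (adapter_config.json plus adapter_model.bin or adapter_model.safetensors), so there is no config.json or pytorch_model.bin, and the *.index.json files appear only when full model weights are split into shards. Here is a small sketch for checking what a checkpoint folder contains. describe_checkpoint is a hypothetical helper, and the file names are the ones these libraries commonly use, so they may differ across versions.

```python
import os


def describe_checkpoint(checkpoint_dir: str) -> str:
    """Classify a checkpoint folder by the weight files it contains.

    Returns "adapter", "sharded", "single-file", or "unknown".
    Hypothetical helper; the file names are assumptions based on common
    transformers/peft naming.
    """
    files = set(os.listdir(checkpoint_dir))
    if "adapter_config.json" in files:
        return "adapter"  # PEFT adapter only, no full model weights
    if {"pytorch_model.bin.index.json", "model.safetensors.index.json"} & files:
        return "sharded"  # full weights split across several shard files
    if {"pytorch_model.bin", "model.safetensors"} & files:
        return "single-file"  # full weights in one file, so no index is needed
    return "unknown"
```

If this reports "adapter", one option is to load the base model first and then attach the adapter with PeftModel.from_pretrained(base_model, checkpoint_dir) from peft, rather than loading the checkpoint folder directly.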