I am trying to fine-tune a model on the HH-RLHF dataset, which has about 161k rows of training data. I launch training with torchrun --nnodes 1 --nproc_per_node 8 sft.py, which, as far as I understand, uses all 8 GPUs. The per-GPU batch size is 4 and the gradient accumulation steps are 1, so the number of steps per epoch should be around 161k / (8 * 4 * 1) ≈ 5k. However, the trainer only trains the model for 40 steps. Where does this number come from?
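For reference, here is the back-of-the-envelope calculation behind the ~5k figure (a rough sketch, assuming 8 data-parallel processes and that the last partial batch is dropped):

dataset_size = 161_000  # approximate number of HH-RLHF training rows
num_processes = 8       # one process per GPU under torchrun --nproc_per_node 8
per_device_batch_size = 4
gradient_accumulation_steps = 1

effective_batch_size = num_processes * per_device_batch_size * gradient_accumulation_steps
steps_per_epoch = dataset_size // effective_batch_size
print(effective_batch_size, steps_per_epoch)  # 32, 5031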
Below is the code for sft.py:
import argparse
import os
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, logging, set_seed
from trl import SFTTrainer
os.environ["WANDB_PROJECT"] = "llama-hh-rlhf"
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")
    parser.add_argument("--num_warmup_steps", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./llama-7b-sft")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1, type=int)
    parser.add_argument("--save_freq", default=1, type=int)
    return parser.parse_args()

def run_training(args):
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
args.model_path,
load_in_8bit=True,
device_map={"": Accelerator().process_index}
)
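    # Load HH-RLHF and estimate how many optimizer steps one epoch should take
    # (the hard-coded 8 is the number of data-parallel GPU processes).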
print("Loading dataset...")
dataset = load_dataset("Anthropic/hh-rlhf")
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
dataset_length = len(train_dataset)
effective_batch_size = 8 * args.batch_size * args.gradient_accumulation_steps
num_train_epochs = args.num_train_epochs
num_steps = (dataset_length // effective_batch_size) * num_train_epochs
print("dataset_length:", dataset_length)
print("num_steps:", num_steps)
print("Setting up training...")
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
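    # Trainer configuration; eval/save/log frequencies all default to every step.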
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        num_train_epochs=args.num_train_epochs,
        evaluation_strategy="steps",
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-sft",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )
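    # SFTTrainer tokenizes the "chosen" field of each example and applies the
    # LoRA adapter on top of the 8-bit base model.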
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="chosen",
        peft_config=peft_config,
    )
    print("Training...")
    trainer.train()
    print("Saving model...")
    trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))

if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    logging.set_verbosity_error()
    run_training(args)