How is the number of steps calculated in trl's SFTTrainer under multiple-GPU?

I am trying to fine-tune the model on the HH-RLHF dataset, which has 161k rows of training data. I launch training with `torchrun --nnodes 1 --nproc_per_node 8 sft.py`, which, as I understand it, uses all 8 GPUs. The per-GPU batch size is 4 and the number of gradient accumulation steps is 1. Therefore, the number of steps should be around 161k / (8 * 4 * 1) ≈ 5k. However, the trainer only trains the model for 40 steps. How is this number computed?

Below is the code for sft.py:

import argparse
import os

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, logging, set_seed
from trl import SFTTrainer


# Send every Weights & Biases run started by this script to one fixed project.
os.environ.update(WANDB_PROJECT="llama-hh-rlhf")


def get_args():
    """Parse the command-line options for LoRA supervised fine-tuning.

    Boolean options that are on by default (``--bf16``,
    ``--gradient_checkpointing``) also get a ``--no_*`` switch. A plain
    ``action="store_true"`` with ``default=True`` could never be turned off.

    ``bf16`` defaults to ``not fp16``. Passing ``--fp16`` alone therefore
    yields a valid fp16 configuration instead of enabling both precisions at
    once, which ``TrainingArguments`` rejects.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="huggyllama/llama-7b")

    # -1 means "derive the step count from num_train_epochs"; a positive value overrides it.
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=4)  # per device
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)

    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")
    parser.add_argument("--num_warmup_steps", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Accept both spellings: some distributed launchers pass --local-rank (hyphen).
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", default=False)
    # None marks "not given on the command line"; resolved after parsing below.
    parser.add_argument("--bf16", action="store_true", default=None)
    parser.add_argument("--no_bf16", dest="bf16", action="store_false")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument(
        "--no_gradient_checkpointing", dest="gradient_checkpointing", action="store_false"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./llama-7b-sft")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1, type=int)
    parser.add_argument("--save_freq", default=1, type=int)

    args = parser.parse_args()
    if args.bf16 is None:
        # bf16 stays the default precision unless fp16 was explicitly requested.
        args.bf16 = not args.fp16
    return args


def run_training(args):
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        load_in_8bit=True,
        device_map={"": Accelerator().process_index}
    )

    print("Loading dataset...")
    dataset = load_dataset("Anthropic/hh-rlhf")
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    dataset_length = len(train_dataset)
    effective_batch_size = 8 * args.batch_size * args.gradient_accumulation_steps
    num_train_epochs = args.num_train_epochs
    num_steps = (dataset_length // effective_batch_size) * num_train_epochs
    print("dataset_length:", dataset_length)
    print("num_steps:", num_steps)

    print("Setting up training...")
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        num_train_epochs=args.num_train_epochs,
        evaluation_strategy="steps",
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-sft",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="chosen",
        peft_config=peft_config,
    )

    print("Training...")
    trainer.train()

    print("Saving model...")
    trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


if __name__ == "__main__":
    # Parse the CLI once; every step below is driven by these options.
    args = get_args()

    # Ensure the output directory exists before anything writes into it
    # (exist_ok makes re-runs and concurrent ranks harmless).
    os.makedirs(args.output_dir, exist_ok=True)
    # transformers.set_seed seeds the RNGs used during training.
    set_seed(args.seed)

    # Only let transformers emit errors.
    # NOTE: this likely also hides the Trainer's info-level
    # "***** Running training *****" summary. That summary reports the
    # trainer's own example count and total optimization steps; switch to
    # logging.set_verbosity_info() when debugging step counts.
    logging.set_verbosity_error()

    run_training(args)
2 Likes

I experienced a similar issue, where the actual number of training steps was far below the “predicted” number of training steps.

1 Like

Facing the same issue…