Why are there only 3 steps per epoch when the dataset has ~2,500 rows and `per_device_train_batch_size` is 1?

Why does one epoch take only 3 steps when my dataset has ~2,500 rows and I’m using `per_device_train_batch_size=1`? I would expect about 2,500 steps per epoch. I’m using `SFTTrainer` with a dataset from the Hugging Face Hub.

Here’s the code; the problem reproduces in Colab on a T4 GPU. (NOTE: this is a copy-paste of an official Gemma 7B example with `max_steps` and `gradient_accumulation_steps` — as well as `warmup_steps` and `logging_steps` — commented out.)

import os
from google.colab import userdata

# Expose the Colab secret as an environment variable; the tokenizer and model
# loading code later reads it back through os.environ['HF_TOKEN'].
os.environ.update(HF_TOKEN=userdata.get('HF_TOKEN'))

# %

# Pin the exact library versions this example was written against.
# The leading `!` is an IPython/Colab shell escape, so this cell only runs
# inside a notebook, not as a plain .py file.
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.1

# %

import torch
# NOTE(review): GemmaTokenizer is imported but never used in this script.
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer

model_id = "google/gemma-7b"
# Load the weights 4-bit quantized (NF4) with bfloat16 as the compute dtype,
# presumably so the 7B model fits in the memory of a single T4.
# NOTE(review): TrainingArguments below sets fp16=True while compute runs in
# bfloat16, and T4 GPUs have no native bf16 support -- confirm this is intended.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Both downloads authenticate with the token exported from Colab secrets above
# (the Gemma repo is presumably gated, hence the token).
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
# device_map={"": 0} maps the root module "" -- i.e. the whole model -- onto CUDA device 0.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])

# %

from peft import LoraConfig

# Rank-8 LoRA adapters for a causal-LM task, applied to the modules named in
# target_modules (the projection layers listed below).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules=[
        "q_proj", "o_proj", "k_proj", "v_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# %

from datasets import load_dataset

# Download the English-quotes dataset and add the tokenizer's outputs for the
# "quote" column, processing rows in batches.
# NOTE(review): SFTTrainer below builds its own tokenized copy from
# formatting_func's text, so these added columns appear unused -- confirm.
data = load_dataset("Abirate/english_quotes").map(
    lambda batch: tokenizer(batch["quote"]),
    batched=True,
)

# %

import transformers
from trl import SFTTrainer

def formatting_func(example):
    """Render a batch of quote rows as SFT training texts.

    Args:
        example: A batch in column-major form: a mapping from column name to
            a list of values, e.g. {"quote": [...], "author": [...]}.

    Returns:
        A list with one two-line "Quote: <quote>" / "Author: <author>" string
        per row of the batch, in row order.
    """
    # SFTTrainer applies this function to *batches* of rows, and each string
    # returned becomes one training example. Indexing only [0] would keep just
    # the first row of every batch. With datasets' default map batch size of
    # 1000, that shrinks ~2500 rows to 3 examples, which is why an epoch had
    # only 3 steps.
    return [
        f"Quote: {quote}\nAuthor: {author}"
        for quote, author in zip(example["quote"], example["author"])
    ]

# Build the supervised fine-tuning trainer. Trainer derives steps per epoch
# from the dataset SFTTrainer builds internally -- one training example per
# string that formatting_func returns -- divided by the effective batch size
# (per-device batch size x devices x gradient accumulation), NOT from
# len(data["train"]).
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        # gradient_accumulation_steps=4,
        # warmup_steps=2,
        # If re-enabled, max_steps caps total training steps and overrides
        # num_train_epochs (which otherwise defaults to 3).
        # max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        # logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    # Must return one formatted string per row of each batch it receives;
    # the number of strings it produces is the training-set size.
    formatting_func=formatting_func,
)

print(f"Num devices: {torch.cuda.device_count()}")
# Raw row count of the source split. SFTTrainer trains on its own processed
# copy (presumably exposed as trainer.train_dataset), whose length can differ.
print(f"Dataset len: {len(data['train'])}")
trainer.train()

The screenshot below shows that my dataset has ~2,500 rows, but `trainer.train()` runs only 3 steps per epoch (and `num_train_epochs` defaults to 3, so I see 9 steps in total). The docs suggest that `Trainer` should make one full pass over the dataset per epoch.