Why does 1 epoch equal only 3 steps when my dataset has 2500 rows and I'm using per_device_train_batch_size=1? I would expect roughly 2500 steps per epoch. I'm using SFTTrainer with a dataset loaded from the HF Hub.
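For context, this is the arithmetic I'm expecting (just a sketch with my own variable names, not Trainer internals; I'm assuming the defaults of gradient_accumulation_steps=1 and a single device):

import math

dataset_len = 2500                 # roughly len(data["train"])
per_device_train_batch_size = 1
gradient_accumulation_steps = 1    # not set, so I assume the default of 1
num_devices = 1                    # single T4 in Colab

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
expected_steps_per_epoch = math.ceil(dataset_len / effective_batch_size)
print(expected_steps_per_epoch)    # 2500, not 3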
Here's the code; it's reproducible in a Colab on a T4. (Note: this is a copy-paste of an official Gemma 7B example with max_steps and gradient_accumulation_steps commented out.)
import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
# %
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.1
# %
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
model_id = "google/gemma-7b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
# %
from peft import LoraConfig
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
# %
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
# %
import transformers
from trl import SFTTrainer
def formatting_func(example):
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [text]
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        # gradient_accumulation_steps=4,
        # warmup_steps=2,
        # max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        # logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
print(f"Num devices: {torch.cuda.device_count()}")
print(f"Dataset len: {len(data['train'])}")
trainer.train()
The screenshot below shows that my dataset has ~2500 rows, but trainer.train() does only 3 steps per epoch (and num_train_epochs defaults to 3, so we see 9 steps for 3 epochs). Docs online suggest the Trainer should do a full pass over the dataset each epoch.
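In case it helps, here is a quick check I can run right after constructing the trainer to compare the raw dataset size with what the trainer will actually iterate over. This is just a sketch; I'm assuming trainer.train_dataset holds the dataset after SFTTrainer's preprocessing, and get_train_dataloader() is the standard Trainer method:

# Compare raw rows vs. what the trainer actually steps through
print(f"Raw rows:          {len(data['train'])}")
print(f"Processed rows:    {len(trainer.train_dataset)}")
print(f"Batches per epoch: {len(trainer.get_train_dataloader())}")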