Hi community,
I am currently working on a project where I train a model (e.g., bert-base-uncased) and push it to my private Hugging Face Hub repository, along with its optimizer state, using the Trainer class.
In my workflow, I save model checkpoints and later want to resume training from a specific checkpoint (which includes the model and optimizer state) that I previously pushed to the Hub.
Currently, I am manually downloading the checkpoint files from the Hub using hf_hub_download() and then using them to resume training. Here’s an outline of what I’ve done:
- Training and saving checkpoints: I save checkpoints during training with Trainer, and I specify hub_strategy="all_checkpoints" to push the model and optimizer state to the Hub.
- Fetching the checkpoint: To resume training, I use hf_hub_download() to download the checkpoint files (model, optimizer, etc.) from the Hub to my local machine.
- Resuming training: I use the locally fetched checkpoint files (model and optimizer state) to resume training with the Trainer class by specifying resume_from_checkpoint.
However, this process feels a bit cumbersome, as I have to manually fetch the checkpoint files (see my fetch_checkpoint() function below). I was wondering if there’s a simpler way to resume training directly from a checkpoint stored on the Hub (without manually downloading files) — especially when the checkpoint contains the model and optimizer states.
Is there a built-in function or more streamlined way to directly resume training from a checkpoint on the Hugging Face Hub?
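For context, the closest simplification I’ve found so far is snapshot_download() with allow_patterns, which pulls the whole checkpoint folder in one call instead of listing and downloading files one by one — but it’s still an explicit local download, so I’m not sure it counts as resuming "directly" from the Hub. A minimal sketch (folder names match my example below):

from huggingface_hub import snapshot_download

# Downloads only the files matching checkpoint-5/* into fetched_checkpoint/,
# replacing the fs.ls() + hf_hub_download() loop in my fetch_checkpoint() below
snapshot_download(
    repo_id="my-user/my-private-model",
    allow_patterns="checkpoint-5/*",
    local_dir="fetched_checkpoint",
)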
Here’s my current code:
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from huggingface_hub import HfApi, HfFolder, login, hf_hub_download, HfFileSystem

repo_name = "my-user/my-private-model"
login("your_hf_token")
token = HfFolder.get_token()
def get_dataset():
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    # Create a dummy dataset
    data = {
        "text": ["I love this!", "This is bad.", "Amazing!", "Not great...", "Superb!"],
        "labels": [1, 0, 1, 0, 1],  # 1 = positive, 0 = negative
    }
    dataset = Dataset.from_dict(data)

    # Tokenize the data
    def tokenize_batch(examples):
        return tokenizer(
            examples["text"], padding="max_length", truncation=True, max_length=32
        )

    dataset = dataset.map(tokenize_batch, batched=True)
    return dataset
def save():
    # Authenticate with the Hugging Face Hub and create the private repo
    api = HfApi()
    api.create_repo(repo_name, private=True, token=token, exist_ok=True)

    # Load the model and tokenizer from the public Hugging Face Hub
    model_name = "bert-base-uncased"
    model = BertForSequenceClassification.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    dataset = get_dataset()

    # Train
    training_args = TrainingArguments(
        output_dir="./my-experiment",
        per_device_train_batch_size=1,
        num_train_epochs=2,
        logging_steps=1,
        save_strategy="epoch",  # with 5 examples and batch size 1, the first epoch checkpoint is checkpoint-5
        push_to_hub=True,  # push the model to the private Hugging Face Hub
        hub_model_id=repo_name,
        report_to="none",
        hub_strategy="all_checkpoints",  # push every checkpoint, including the optimizer state
        save_steps=5,  # ignored here: save_strategy="epoch" takes precedence
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()

    # Push the tokenizer to the private Hugging Face Hub
    tokenizer.push_to_hub(repo_name, private=True)
def fetch_checkpoint():
    # To resume training, you need the previous optimizer state, so
    # download the checkpoint files from the Hub to a local folder
    checkpoint_folder = "checkpoint-5"
    local_checkpoint_folder = "fetched_checkpoint"

    # List the files in the checkpoint folder on the Hub
    fs = HfFileSystem()
    files_in_folder = fs.ls(f"{repo_name}/{checkpoint_folder}", detail=False)
    files_in_folder = [file.split("/")[-1] for file in files_in_folder]

    # Download each file from the subfolder
    downloaded_files = []
    for file_name in files_in_folder:
        file_path = hf_hub_download(
            repo_id=repo_name,
            filename=file_name,
            subfolder=checkpoint_folder,
            local_dir=local_checkpoint_folder,
        )
        downloaded_files.append(file_path)
def resume_training():
    # Resume training from the locally fetched checkpoint
    checkpoint_folder = "checkpoint-5"
    local_checkpoint_folder = "fetched_checkpoint"
    checkpoint_path = f"{local_checkpoint_folder}/{checkpoint_folder}"
    dataset = get_dataset()
    model = BertForSequenceClassification.from_pretrained(checkpoint_path)

    training_args = TrainingArguments(
        output_dir="./my-resume-experiment",
        per_device_train_batch_size=1,
        num_train_epochs=3,
        logging_steps=1,
        save_strategy="epoch",
        hub_model_id=repo_name,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    # Passing resume_from_checkpoint to train() restores the optimizer/scheduler
    # and trainer state (global step, epoch); setting it in TrainingArguments
    # as well is redundant, so I only pass it here
    trainer.train(resume_from_checkpoint=checkpoint_path)
if __name__ == "__main__":
    save()
    fetch_checkpoint()
    resume_training()
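For completeness: as far as I understand, trainer.train(resume_from_checkpoint=True) only looks for the latest checkpoint-* folder inside the local output_dir, so it doesn’t help when the checkpoint exists only on the Hub. A minimal sketch of that local-only variant (the helper name resume_from_local_output_dir is just for illustration), assuming ./my-experiment still contains the checkpoints written by save():

def resume_from_local_output_dir():
    # resume_from_checkpoint=True makes Trainer pick the most recent
    # checkpoint-* folder inside output_dir, but only from the local disk
    dataset = get_dataset()
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    training_args = TrainingArguments(
        output_dir="./my-experiment",  # must already hold the checkpoints
        per_device_train_batch_size=1,
        num_train_epochs=3,
        logging_steps=1,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train(resume_from_checkpoint=True)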