Resuming Training from Checkpoints Stored on the Hugging Face Hub (without downloading manually)

Hi community,

I am currently working on a project where I train a model (e.g., bert-base-uncased) and push it to my private Hugging Face Hub repository, along with its optimizer state, using the Trainer class.

In my workflow, I save model checkpoints and later want to resume training from a specific checkpoint (which includes the model and optimizer state) that I previously pushed to the Hub.

Currently, I am manually downloading the checkpoint files from the Hub using hf_hub_download() and then using them to resume training. Here’s an outline of what I’ve done:

  1. Training and Saving Checkpoints: I save checkpoints during training with Trainer, and I specify hub_strategy="all_checkpoints" to push the model and optimizer state to the Hub.
  2. Fetching Checkpoint: To resume training, I use hf_hub_download() to download the checkpoint files (model, optimizer, etc.) from the Hub to my local machine.
  3. Resuming Training: I use the locally fetched checkpoint files (model and optimizer state) to resume training using the Trainer class by specifying resume_from_checkpoint.

However, this process feels a bit cumbersome, as I have to manually fetch the checkpoint files (see my fetch_checkpoint() function below). I was wondering if there’s a simpler way to resume training directly from a checkpoint stored on the Hub (without manually downloading files) — especially when the checkpoint contains the model and optimizer states.

Is there a built-in function or more streamlined way to directly resume training from a checkpoint on the Hugging Face Hub?

Here’s my current code:

from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from huggingface_hub import HfApi, HfFolder, login, hf_hub_download, HfFileSystem

repo_name = "my-user/my-private-model"
login("your_hf_token")
token = HfFolder.get_token()


def get_dataset():
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    # Create Dummy Dataset
    data = {
        "text": ["I love this!", "This is bad.", "Amazing!", "Not great...", "Superb!"],
        "labels": [1, 0, 1, 0, 1],  # 1 = Positive, 0 = Negative
    }
    dataset = Dataset.from_dict(data)

    # Tokenize Data
    def tokenize_batch(examples):
        return tokenizer(
            examples["text"], padding="max_length", truncation=True, max_length=32
        )

    dataset = dataset.map(tokenize_batch, batched=True)
    return dataset


def save():
    # Authenticate with Hugging Face Hub
    api = HfApi()
    api.create_repo(repo_name, private=True, token=token, exist_ok=True)

    # Load the model and tokenizer from public Hugging Face Hub
    model_name = "bert-base-uncased"
    model = BertForSequenceClassification.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

    dataset = get_dataset()

    # Train
    training_args = TrainingArguments(
        output_dir="./my-experiment",
        per_device_train_batch_size=1,
        num_train_epochs=2,
        logging_steps=1,
        save_strategy="epoch",
        push_to_hub=True,  # Push the model to the private Hugging Face Hub
        hub_model_id=repo_name,
        report_to="none",
        hub_strategy="all_checkpoints",  # This will save the model as a checkpoint including the optimizer state
        save_steps=5,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()

    # Push the tokenizer to the private Hugging Face Hub
    tokenizer.push_to_hub(repo_name, private=True)


def fetch_checkpoint():
    # To resume training, you need the previous optimizer state;
    # the following code downloads the checkpoint files locally.
    checkpoint_folder = "checkpoint-5"
    local_checkpoint_folder = "fetched_checkpoint"

    # List the files in the checkpoint folder
    fs = HfFileSystem()
    files_in_folder = fs.ls(f"{repo_name}/{checkpoint_folder}", detail=False)
    files_in_folder = [file.split("/")[-1] for file in files_in_folder]

    # Download the files from the subfolder
    downloaded_files = []
    for file_name in files_in_folder:
        file_path = hf_hub_download(
            repo_id=repo_name,
            filename=file_name,
            subfolder=checkpoint_folder,
            local_dir=local_checkpoint_folder,
        )
        downloaded_files.append(file_path)


def resume_training():
    # Resume training from the local checkpoint
    checkpoint_folder = "checkpoint-5"
    local_checkpoint_folder = "fetched_checkpoint"
    checkpoint_path = f"{local_checkpoint_folder}/{checkpoint_folder}"
    dataset = get_dataset()

    model = BertForSequenceClassification.from_pretrained(checkpoint_path)
    training_args = TrainingArguments(
        output_dir="./my-resume-experiment",
        per_device_train_batch_size=1,
        num_train_epochs=3,
        logging_steps=1,
        save_strategy="epoch",
        hub_model_id=repo_name,
        report_to="none",
        resume_from_checkpoint=checkpoint_path,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train(resume_from_checkpoint=checkpoint_path)


if __name__ == "__main__":
    save()
    fetch_checkpoint()
    resume_training()


I think the simplest options are the Trainer’s built-in resume mechanism and the from_pretrained method, which is commonly used even when simply loading a model.
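
For example, something like this (an untested sketch; from_pretrained can read a subfolder of a Hub repo directly, but it only restores the model weights, not the optimizer or scheduler state saved in the checkpoint):

from transformers import BertForSequenceClassification

# Sketch: load the model weights straight from a checkpoint subfolder on the Hub.
# Note that this does NOT restore the optimizer/scheduler state that Trainer
# saved alongside the weights in the checkpoint folder.
model = BertForSequenceClassification.from_pretrained(
    "my-user/my-private-model",
    subfolder="checkpoint-5",
)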

Thanks for your response! I appreciate the links.

I might be misunderstanding something, but from what I see, the examples you shared assume that the checkpoint is stored locally. My question is more about whether Trainer can resume training directly from a checkpoint stored on the Hugging Face Hub, without manually downloading the files first (e.g., using hf_hub_download).

In my example above, I’d like to resume training from checkpoint-5, located in my Hub repo at my-user/my-private-model/checkpoint-5.

From my testing, it seems like Trainer’s resume_from_checkpoint only works when the checkpoint is already on the local filesystem, which is why I had to manually fetch it from the Hub before resuming. If there’s a way to do this more seamlessly, I’d love to know!

Let me know if I’m missing something - maybe I’m overcomplicating it. Thanks again! :blush:


HF libraries generally have trouble handling subfolders, but it seems get_last_checkpoint can be used for local (offline) checkpoints. I wonder whether it can also be used for checkpoints on the Hub…?
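
For local folders it is just something like this (sketch):

from transformers.trainer_utils import get_last_checkpoint

# Scans a local output dir for checkpoint-* subfolders and returns
# the path of the one with the highest step number (or None).
last_ckpt = get_last_checkpoint("./my-experiment")
print(last_ckpt)  # e.g. ./my-experiment/checkpoint-10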

That’s a good point! I looked at the code for get_last_checkpoint, and since it uses os.listdir, it seems to be designed to work only locally. So, unless I’m missing something, it wouldn’t help for fetching checkpoints directly from the Hub.

Also, I’d like to be able to retrieve any checkpoint, not just the last one.
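
As a workaround, I can at least list the available checkpoints on the Hub myself (a rough sketch using list_repo_files; the regex just extracts the checkpoint-<step> folder names):

import re
from huggingface_hub import HfApi

def list_hub_checkpoints(repo_id):
    # Rough sketch: collect the distinct checkpoint-<step> folder names
    # present in the repo, sorted by step number.
    files = HfApi().list_repo_files(repo_id)
    steps = {
        int(m.group(1))
        for f in files
        if (m := re.match(r"checkpoint-(\d+)/", f))
    }
    return [f"checkpoint-{s}" for s in sorted(steps)]

print(list_hub_checkpoints("my-user/my-private-model"))  # e.g. ['checkpoint-5', 'checkpoint-10']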

At this point, I’d just like confirmation that my way of fetching checkpoints from the Hub is the correct one and that there isn’t a more seamless, built-in way to do it. I’m happy to hear any advice or discuss alternative approaches if I’ve overlooked something!

Looking forward to your thoughts! :blush:


I see. So the current method is probably correct. If you want to deviate from the one-repo-one-model convention (revision management is possible instead), it is more reliable to manage the checkpoints yourself. And if you use hf_hub_download or snapshot_download, there shouldn’t be much difference in speed compared with from_pretrained…
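
For example, snapshot_download can grab a single checkpoint folder in one call (untested sketch; allow_patterns restricts the download to that folder):

from huggingface_hub import snapshot_download

# Sketch: download only the checkpoint-5 folder from the repo,
# then hand the resulting local path to Trainer.
local_dir = snapshot_download(
    repo_id="my-user/my-private-model",
    allow_patterns="checkpoint-5/*",
)
checkpoint_path = f"{local_dir}/checkpoint-5"
# trainer.train(resume_from_checkpoint=checkpoint_path)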

Thanks for the clarification! I’ll stick with my approach then.
Appreciate your help! :rocket:

