How to load a pretrained custom model using `from_pretrained`

Hi there,

I want to create a custom model that includes a transformer and, after training it for a few epochs, save it with the `save_pretrained` function. I would then like to load it in a different notebook with the `from_pretrained` function for inference.

Suppose I followed this guide and created a custom model named `CustomModel` with something like:

```python
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel

class CustomModel(PreTrainedModel):
    def __init__(self, config, transformer_model_name, n_dims=1000, n_factors=50, n_classes=10):
        super().__init__(config)
        self.embs = nn.Embedding(n_dims, n_factors)
        # pretrained transformer whose 512 "labels" are used as a text feature vector
        self.text_transformer = AutoModelForSequenceClassification.from_pretrained(
            transformer_model_name, num_labels=512)
        self.linear_layers = nn.Sequential(
            nn.Linear(n_factors + 512, 256, bias=False),
            nn.LeakyReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(),
            nn.Linear(256, n_classes)
        )
```

Suppose I have already trained and saved the model. Can I later use `CustomModel.from_pretrained(model_dir)` to load the trained model in a different notebook? I tried something along these lines, but calling `from_pretrained` raised `AttributeError: 'NoneType' object has no attribute 'from_pretrained'`.

I'm eager to figure out what I did wrong and what the best approach would be. Any suggestions would be much appreciated.

Thank you very much!

Yes, you can subclass `PreTrainedModel` to inherit methods such as `from_pretrained`, `save_pretrained` and `push_to_hub`.

Alternatively, you can use the `PyTorchModelHubMixin` class from the `huggingface_hub` library, which gives you the same functionality:

```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class CustomModel(nn.Module, PyTorchModelHubMixin):
    ...  # define __init__ and forward as for any nn.Module
```

Thank you, this is exactly what I was looking for!

Nice, thanks!

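The reason for this error is that the class definition does not set `config_class`. `PreTrainedModel.config_class` defaults to `None`, and `from_pretrained` loads the configuration through `cls.config_class.from_pretrained(...)`, hence the `'NoneType' object has no attribute 'from_pretrained'` error. Set it to your config class:

```python
class CustomModel(PreTrainedModel):
    config_class = CustomConfig  # your PretrainedConfig subclass

    def __init__(self, config):
        super().__init__(config)
        ...
```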

The guide linked in the original question contains an example of this.
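Putting the fix together, here is a minimal sketch of a config/model pair that round-trips through `save_pretrained` and `from_pretrained`. `CustomConfig`, its `model_type` string and the small layer sizes are hypothetical, and the pretrained transformer is omitted so it runs offline. Every hyperparameter lives on the config because `from_pretrained` rebuilds the model as `cls(config, *model_args, **kwargs)`; with the question's original signature, `transformer_model_name` would also have to be passed to `from_pretrained`, or stored as a config attribute instead.

```python
import tempfile

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class CustomConfig(PretrainedConfig):
    model_type = "custom-model"  # hypothetical identifier for this architecture

    def __init__(self, n_dims=1000, n_factors=50, n_classes=10, **kwargs):
        self.n_dims = n_dims
        self.n_factors = n_factors
        self.n_classes = n_classes
        super().__init__(**kwargs)


class CustomModel(PreTrainedModel):
    # Without this, config_class stays None and from_pretrained fails
    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # Build every layer from the config so from_pretrained can rebuild it
        self.embs = nn.Embedding(config.n_dims, config.n_factors)
        self.head = nn.Linear(config.n_factors, config.n_classes)
        self.post_init()  # transformers' standard weight-init/setup hook

    def forward(self, input_ids):
        return self.head(self.embs(input_ids))


model = CustomModel(CustomConfig(n_dims=100, n_factors=8, n_classes=3))

with tempfile.TemporaryDirectory() as model_dir:
    model.save_pretrained(model_dir)                   # config.json + weights
    reloaded = CustomModel.from_pretrained(model_dir)  # uses CustomConfig to load
```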