Saving/loading a custom model built on top of varying HF models

I am writing a custom head for a transformer and was wondering how to add the functionality to correctly load/save the model using the HF functions (from_pretrained()…). Specifically, I would like my model class to work with any transformer from the HF hub.

The use case would be:

  1. The user passes the name of a model on the Hub to from_pretrained() → load the pretrained transformer weights and randomly initialize the head layers
  2. The user passes the path to a saved checkpoint → load both the transformer and the head weights
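The two use cases differ mainly in the parameter names stored in the checkpoint. The sketch below makes this visible offline. It rests on a few assumptions: a tiny, randomly initialized `BertConfig`/`BertModel` stands in for a real Hub model, and a hypothetical plain `nn.Module` called `Wrapper` stands in for the custom class (`my_encoder` is just an example attribute name).

```python
import torch.nn as nn
from transformers import BertConfig, BertModel

# Tiny, randomly initialized BERT: an offline stand-in for a Hub encoder.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
)
encoder = BertModel(config)


class Wrapper(nn.Module):
    """Hypothetical custom model: encoder under `my_encoder`, plus a head."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.my_encoder = encoder
        self.classifier = nn.Linear(encoder.config.hidden_size, num_classes)


wrapper = Wrapper(encoder, num_classes=2)

# Use case 1: a Hub checkpoint of a bare encoder stores unprefixed keys.
hub_keys = list(encoder.state_dict().keys())
# Use case 2: a checkpoint of the whole custom model stores prefixed keys.
custom_keys = list(wrapper.state_dict().keys())

print("embeddings.word_embeddings.weight" in hub_keys)               # True
print("my_encoder.embeddings.word_embeddings.weight" in custom_keys)  # True
print([k for k in custom_keys if k.startswith("classifier.")])       # ['classifier.weight', 'classifier.bias']
```

In use case 1, the loader therefore has to map unprefixed encoder keys onto the prefixed attribute. In use case 2, the names already match one-to-one.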

Toy example:

from transformers import PreTrainedModel, AutoModel, AutoConfig
import torch.nn as nn

class MyCustomModel(PreTrainedModel):
    def __init__(self, config, transformer_model_name_or_path, num_classes):
        super(MyCustomModel, self).__init__(config)
        self.transformer = AutoModel.from_pretrained(transformer_model_name_or_path, config=config)
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Linear(config.hidden_size, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

    def _init_weights(self, module):
        # body omitted in this toy example
        pass

config = AutoConfig.from_pretrained("sentence-transformers/all-distilroberta-v1")
model = MyCustomModel.from_pretrained(
    "sentence-transformers/all-distilroberta-v1",
    config=config,
    transformer_model_name_or_path="sentence-transformers/all-distilroberta-v1",
    num_classes=2,
)

This code runs, but it does not load the weights of the distilroberta model correctly. I want to be able to use different models such as BERT or RoBERTa, which is why I am using the Auto classes, but with them the weights end up not being loaded correctly.
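One way to see which weights from_pretrained() could not match is to pass `output_loading_info=True`, which makes it also return a dict with `"missing_keys"` and `"unexpected_keys"`. The sketch below is only a stand-in for the toy model, with these assumptions: a hypothetical `NoPrefixModel` that keeps `PreTrainedModel`'s default (empty) `base_model_prefix`, a tiny random BERT saved to a temporary directory instead of `sentence-transformers/all-distilroberta-v1`, and `BertModel(config)` instead of `AutoModel.from_pretrained(...)` inside `__init__`.

```python
import tempfile

from transformers import BertConfig, BertModel, PreTrainedModel

# Tiny, randomly initialized BERT: an offline stand-in for a Hub encoder.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=37,
)


class NoPrefixModel(PreTrainedModel):
    """Hypothetical stand-in for the toy model: the encoder lives under
    `transformer`, but base_model_prefix keeps PreTrainedModel's default ("")."""

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = BertModel(config)
        self.post_init()

    def _init_weights(self, module):
        # Reuse BERT's own initialization scheme.
        self.transformer._init_weights(module)


with tempfile.TemporaryDirectory() as hub_dir:
    # A bare encoder checkpoint: keys like "embeddings.word_embeddings.weight".
    BertModel(config).save_pretrained(hub_dir)
    model, info = NoPrefixModel.from_pretrained(hub_dir, output_loading_info=True)

missing = sorted(info["missing_keys"])        # model keys not found in the checkpoint
unexpected = sorted(info["unexpected_keys"])  # checkpoint keys the model did not use
print(len(missing), "missing, e.g.", missing[:2])
print(len(unexpected), "unexpected, e.g.", unexpected[:2])
```

With no prefix declared, the encoder keys show up twice: as `transformer.*` entries under missing and as unprefixed entries under unexpected. In other words, the checkpoint's encoder weights are not matched to the model's encoder.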

I then managed to build the following, which works for loading any kind of HF model into my custom model:

class MyCustomModel(PreTrainedModel):
    # Important: config_class must be AutoConfig. With plain PretrainedConfig, from_pretrained
    # would not build the model-specific config class that AutoModel.from_config needs
    config_class = AutoConfig
    # Must match the attribute name of the encoder, i.e. self.my_encoder = AutoModel.from_config(config)
    base_model_prefix = "my_encoder"

    def __init__(self, config, num_classes):
        super(MyCustomModel, self).__init__(config)
        self.my_encoder = AutoModel.from_config(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Linear(config.hidden_size, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.my_encoder(input_ids, attention_mask=attention_mask)
        # Not every AutoModel output has pooler_output (DistilBertModel, for example, does not)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

    def _init_weights(self, module):
        # Reuse the encoder's own weight-initialization scheme for all modules, including the head
        return self.my_encoder._init_weights(module)

This seems to resolve the warnings I was getting. Could someone confirm that this is in fact a robust way of solving the issue?
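Both use cases can be checked offline with a sketch like the one below. It relies on these assumptions: a tiny random `BertModel` saved to a temporary directory stands in for a Hub model, and the class is the `MyCustomModel` above with `self.post_init()` added at the end of `__init__` (the built-in HF model classes end their constructors this way). If the design works, the encoder weights should match the saved ones, only `classifier.*` keys should be reported missing in use case 1, and the head weights should survive a save/reload round trip.

```python
import tempfile

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, BertConfig, BertModel, PreTrainedModel


class MyCustomModel(PreTrainedModel):
    config_class = AutoConfig         # resolves BertConfig, RobertaConfig, ... from config.json
    base_model_prefix = "my_encoder"  # must match the encoder attribute name below

    def __init__(self, config, num_classes):
        super().__init__(config)
        self.my_encoder = AutoModel.from_config(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Linear(config.hidden_size, num_classes),
        )
        self.post_init()  # assumption: added here, mirroring HF's own model classes

    def _init_weights(self, module):
        self.my_encoder._init_weights(module)

    def forward(self, input_ids, attention_mask):
        outputs = self.my_encoder(input_ids, attention_mask=attention_mask)
        return self.classifier(outputs.pooler_output)


# Tiny, randomly initialized BERT: an offline stand-in for a Hub model.
tiny = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                  num_attention_heads=2, intermediate_size=37)

with tempfile.TemporaryDirectory() as hub_dir, tempfile.TemporaryDirectory() as ckpt_dir:
    hub_encoder = BertModel(tiny)
    hub_encoder.save_pretrained(hub_dir)

    # Use case 1: encoder weights from the "Hub" checkpoint, head randomly initialized.
    model, info = MyCustomModel.from_pretrained(hub_dir, num_classes=2, output_loading_info=True)
    encoder_loaded = torch.equal(
        model.my_encoder.embeddings.word_embeddings.weight,
        hub_encoder.embeddings.word_embeddings.weight,
    )
    print("encoder loaded:", encoder_loaded)
    print("missing keys:", sorted(info["missing_keys"]))

    # Use case 2: save the whole model, then reload encoder and head together.
    model.save_pretrained(ckpt_dir)
    # num_classes is not stored in config.json, so it has to be passed again.
    reloaded = MyCustomModel.from_pretrained(ckpt_dir, num_classes=2)
    head_loaded = torch.equal(model.classifier[1].weight, reloaded.classifier[1].weight)
    print("head loaded:", head_loaded)
```

One design point the sketch surfaces: because `num_classes` is an extra constructor argument rather than a config field, a saved checkpoint does not record it. Storing it on the config (for example via the existing `num_labels` attribute of `PretrainedConfig`) would be one way to make checkpoints self-describing.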