Registering custom model and config to AutoModel and AutoConfig

Hi everyone,

I am trying to create a custom model on top of a pretrained model, save it, and then use it as a pretrained model for another use case. Here is my code:

encoder_config={
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_width": 768,
}

Custom Encoder Config:

from transformers import PretrainedConfig

class TextEncoderConfig(PretrainedConfig):
    model_type = "TextEncoder"
    '''
        hidden_size = 768
        encoder_width = 768
        num_attention_head = 12
        num_hidden_layers = 12
    '''
    def __init__(self, hidden_size= 512, encoder_width= 512, vocab_size= 30522, 
                 max_position_embeddings=512,
                 attention_probs_dropout_prob=0.1,
                 hidden_act = "gelu",
                 hidden_dropout_prob = 0.1,
                 initializer_range = 0.02,
                 intermediate_size=3072,
                 layer_norm_eps=1e-12,
                 num_attention_heads=8,
                 num_hidden_layers=8,
                 pad_token_id=0,
                 type_vocab_size=2,
                 **kwargs):
        self.hidden_size=hidden_size
        self.encoder_width = encoder_width
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.pad_token_id = pad_token_id
        self.type_vocab_size = type_vocab_size
        super().__init__(**kwargs)

encoder_config = TextEncoderConfig()
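
For reference, the config serializes as I would expect (a quick check using the standard `PretrainedConfig` API):

print(encoder_config.model_type)        # "TextEncoder"
print(encoder_config.to_json_string())  # JSON including "hidden_size": 512, "num_hidden_layers": 8, ...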

Custom Model:

import torch
from transformers import PreTrainedModel, BertModel, BertTokenizer

class TextEncoder(PreTrainedModel):
    config_class = TextEncoderConfig
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = BertTokenizer.from_pretrained('google/bert_uncased_L-8_H-512_A-8')
        self.tokenizer.add_special_tokens({'bos_token':'[DEC]'})
        self.tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
        self.tokenizer.enc_token_id = self.tokenizer.additional_special_tokens_ids[0]
        self.encoder_config = config
        # pretrained bert encoder
        self.encoder = BertModel.from_pretrained('google/bert_uncased_L-8_H-512_A-8', config=self.encoder_config, add_pooling_layer=False, ignore_mismatched_sizes=True)

    def forward(self, in_sen):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer_output = self.tokenizer(in_sen, padding='max_length', truncation=True, max_length=16, return_tensors='pt')
        in_seq = tokenizer_output.input_ids.to(device)
        in_mask = tokenizer_output.attention_mask.to(device)
        text_emb = self.encoder(input_ids=in_seq, attention_mask=in_mask)
        return text_emb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_encoder = TextEncoder(encoder_config).to(device)
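
For reference, this is roughly how I call the model once it is instantiated (the sentences below are just placeholders):

sentences = ["an example sentence", "another example sentence"]
text_emb = text_encoder(sentences)
print(text_emb.last_hidden_state.shape)  # torch.Size([2, 16, 512]) given max_length=16 and hidden_size=512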

Now following this documentation: Sharing custom models

I tried saving the custom config and model:

text_encoder.save_pretrained("CustomModels/TextEncoder")
AutoConfig.register("TextEncoder", TextEncoderConfig)
AutoModel.register(TextEncoderConfig, TextEncoder)
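
As a quick sanity check in the same session (I would expect both of these to work while `TextEncoderConfig` is defined and registered):

cfg = TextEncoderConfig.from_pretrained("CustomModels/TextEncoder")
print(cfg.model_type)                             # "TextEncoder"
print(type(AutoConfig.for_model("TextEncoder")))  # <class '__main__.TextEncoderConfig'>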

After registering the model, I tried loading it:

text_encoder = AutoModel.from_pretrained("CustomModels/TextEncoder")

I get the following error:

KeyError                                  Traceback (most recent call last)
Cell In[3], line 1
----> 1 text_encoder = AutoModel.from_pretrained("CustomModels/TextEncoder")

File /data/anaconda3/envs/bishwa_l3/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:441, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    438     if kwargs_copy.get("torch_dtype", None) == "auto":
    439         _ = kwargs_copy.pop("torch_dtype")
--> 441     config, kwargs = AutoConfig.from_pretrained(
    442         pretrained_model_name_or_path,
    443         return_unused_kwargs=True,
    444         trust_remote_code=trust_remote_code,
    445         **hub_kwargs,
    446         **kwargs_copy,
    447     )
    448 if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
    449     if not trust_remote_code:

File /data/anaconda3/envs/bishwa_l3/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:939, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
    936 elif "model_type" in config_dict:
    938     print(CONFIG_MAPPING)
--> 939     print(CONFIG_MAPPING[config_dict["model_type"]])
    940     print(CONFIG_MAPPING)
    941     config_class = CONFIG_MAPPING[config_dict["model_type"]]

File /data/anaconda3/envs/bishwa_l3/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:643, in _LazyConfigMapping.__getitem__(self, key)
    641     return self._extra_content[key]
    642 if key not in self._mapping:
--> 643     raise KeyError(key)
    644 value = self._mapping[key]
    645 module_name = model_type_to_module_name(key)

KeyError: 'TextEncoder'

I was curious why `TextEncoder` was not found even though it is set as the `model_type`. To debug, I added `print(CONFIG_MAPPING)` statements around the `CONFIG_MAPPING[config_dict["model_type"]]` lookup in `configuration_auto.py` (lines 938–940 in the traceback above).

I get this result:

_LazyConfigMapping()
<class '__main__.TextEncoderConfig'>
_LazyConfigMapping()
_LazyConfigMapping()
<class 'transformers.models.bert.configuration_bert.BertConfig'>
_LazyConfigMapping()

I am also surprised here: how is the original `BertConfig` being stored?

Also, the print line doesn't work if I solely run `text_encoder = AutoModel.from_pretrained("CustomModels/TextEncoder")` without instantiating the `TextEncoder` class first.
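
For completeness, this is the fresh-session reproduction (only the load call, without defining or registering the custom classes first):

from transformers import AutoModel

text_encoder = AutoModel.from_pretrained("CustomModels/TextEncoder")  # raises KeyError: 'TextEncoder'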

Any help would be appreciated.

@JustSaX, @sgugger, @nielsr, do you have any suggestions for this problem?