Hello, I am newer to HuggingFace and wanted to create my own
nn.Module class that used RoBERTa as an encoder. I am also hoping that I would be able to use it with HuggingFace’s
Trainer class. Looking at the source code for
Trainer, it looks like my model’s
forward only needs to return an object with
outputs["loss"]. Is there anything else I need to do? Are there any resources/guides/tutorials for creating your own model?
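For reference, here is a minimal sketch of the contract described above: a plain nn.Module whose forward returns a dict containing a "loss" entry when labels are passed. TinyClassifier and its embedding layer are hypothetical stand-ins for a real RoBERTa encoder, used only so the example runs without downloading weights.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for an encoder plus a classification head."""

    def __init__(self, vocab_size=100, hidden_size=16, num_labels=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)  # placeholder for RoBERTa
        self.head = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids=None, labels=None, **kwargs):
        pooled = self.embed(input_ids).mean(dim=1)  # (batch, hidden)
        logits = self.head(pooled)                  # (batch, num_labels)
        outputs = {"logits": logits}
        if labels is not None:
            # Put the loss first so it is found whether the caller reads
            # outputs["loss"] or the first element.
            outputs = {"loss": nn.functional.cross_entropy(logits, labels), "logits": logits}
        return outputs


model = TinyClassifier()
batch = {"input_ids": torch.randint(0, 100, (4, 8)), "labels": torch.tensor([0, 1, 2, 1])}
outputs = model(**batch)
```

As far as I can tell, the default loss computation in Trainer reads outputs["loss"] when forward returns a dict, so the batch keys produced by the dataset and collator need to match the argument names of forward.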
Hi @Gabe, I’m not aware of any dedicated tutorials for building custom models, but my suggestion would be to subclass
PreTrainedModel (check out how e.g.
BertForSequenceClassification is implemented) or one of the existing model classes. This has several advantages to using
- You get all the helper methods like
- Your custom model will play nice with the
Depending on your use case, you can also override methods directly in the
Trainer - see here for a list of the available methods.
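As an illustration of overriding a Trainer method, the sketch below subclasses Trainer and replaces compute_loss with a multi-label loss. MultiLabelTrainer and TinyTagger are hypothetical names, and the batch is assumed to carry "input_ids" and multi-hot "labels"; the exact compute_loss signature varies between library versions, which is why extra keyword arguments are accepted and ignored.

```python
import torch
from torch import nn
from transformers import Trainer


class MultiLabelTrainer(Trainer):  # hypothetical name
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer Trainer versions may pass.
        labels = inputs["labels"].float()
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        # The loss applies the sigmoid itself, so the model returns raw scores.
        loss = nn.functional.binary_cross_entropy_with_logits(outputs["logits"], labels)
        return (loss, outputs) if return_outputs else loss


class TinyTagger(nn.Module):
    """Hypothetical stand-in model that returns raw multi-label scores."""

    def __init__(self, vocab_size=100, hidden_size=16, num_tags=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, num_tags)

    def forward(self, input_ids=None, **kwargs):
        return {"logits": self.head(self.embed(input_ids).mean(dim=1))}


model = TinyTagger()
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "labels": torch.tensor([[1, 0, 0, 1], [0, 1, 1, 0]]),
}
# This override never touches `self`, so the loss logic can be smoke-tested
# without constructing a full Trainer (which needs TrainingArguments etc.).
loss = MultiLabelTrainer.compute_loss(None, model, batch)
```

In real use the subclass is constructed exactly like Trainer (model, args, datasets) and train() is called as usual.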
Hi @lewtun, sorry for the late reply. Thank you for the suggestion! I have been trying that method out with subclassing the
PreTrainedModel and using a separate
AutoModel as an encoder for the model. I have noticed that my model's
from_pretrained does not seem to load the pretrained encoder. Is there any documentation/examples for this use case?
Hmm, that’s a bit odd. Would you be able to share a code snippet / Colab notebook with your workflow?
For the model/config I used:
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

# TagPredictionOutput is defined elsewhere and is not shown in this post.


class TagPredictionConfig(PretrainedConfig):
    def __init__(
        self,
        model_type: str = "bart",
        encoder_model: str = "facebook/bart-base",
        num_labels: int = 5,
        dropout: float = .5,
        inner_dim: int = 1024,
        max_len: int = 128,
        unique_label_count: int = 10,
        **kwargs
    ):
        super(TagPredictionConfig, self).__init__(num_labels=num_labels, **kwargs)
        self.model_type = model_type
        self.encoder_model = encoder_model
        self.dropout = dropout
        self.inner_dim = inner_dim
        self.max_length = max_len
        self.unique_label_count = unique_label_count
        self.intent_token = '<intent>'
        self.snippet_token = '<snippet>'
        self.columns_used = ['snippet_tokenized', 'canonical_intent', 'tags']
        encoder_config = AutoConfig.from_pretrained(
            self.encoder_model,
        )
        self.vocab_size = encoder_config.vocab_size
        self.eos_token_id = encoder_config.eos_token_id


class TagPredictionModel(PreTrainedModel):
    config_class = TagPredictionConfig

    def __init__(self, config: TagPredictionConfig):
        super(TagPredictionModel, self).__init__(config)
        self.config = config
        self.encoder = AutoModel.from_pretrained(self.config.encoder_model)
        self.encoder.resize_token_embeddings(self.config.vocab_size)
        self.dense_1 = nn.Linear(
            self.encoder.config.hidden_size,
            self.config.inner_dim,
            bias=False
        )
        self.dense_2 = nn.Linear(
            self.config.inner_dim,
            self.config.unique_label_count,
            bias=False
        )
        self.dropout = nn.Dropout(self.config.dropout)
        self.encoder._init_weights(self.dense_1)
        self.encoder._init_weights(self.dense_2)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        **kwargs
    ):
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # labels=labels,
            return_dict=return_dict,
        )
        hidden_states = encoded[0]  # last hidden state
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        encoded_rep = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]

        classification_hidden = self.dropout(encoded_rep)
        classification_hidden = torch.tanh(self.dense_1(classification_hidden))
        classification_hidden = self.dropout(classification_hidden)
        logits = torch.sigmoid(self.dense_2(classification_hidden))

        loss = None
        if labels is not None:
            # Note: BCEWithLogitsLoss applies a sigmoid internally, so it
            # expects raw scores rather than sigmoid outputs.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

        return TagPredictionOutput(
            loss=loss,
            logits=logits
        )
Most of the code is straight from the
BertForSequenceClassification model. Still, because I want to use it with T5 and a multi-label classification task, I had to modify it slightly.
The trainer is (I can give the training args, but thought it would clutter too much, so I left them out):
args_dict = TrainingArguments(**simpleTrainingArgs("./experiments/"))
trainer = Trainer(
    model=model,
    args=args_dict,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
    tokenizer=preprocessor.tokenizer
)
trainer.train()
I think the main issue comes from the default model name in the config. When I do not set the default value for
encoder_model to a real model, it errors out because the model tries to load a pretrained encoder from that name. This happened after the first epoch, when the trainer tried to load the best model. I can do more testing and see if it happens after more than one epoch.
@lewtun to add on through more testing, I now get the warning:
Some weights of TagPredictionModel were not initialized from the model checkpoint at ./experiments/checkpoint-40 and are newly initialized: ['.encoder.shared.weight', '.encoder.encoder.embed_tokens.weight', '.encoder.encoder.embed_positions.weight', '.encoder.encoder.layers.0.self_attn.k_proj.weight'…(Cut for length)
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
EDIT: Also, from rerunning evaluation on the validation set after training ends I am almost certain that it is not saving because the eval loss is different than during the training loop
hey @Gabe, sorry for the slow reply.
if i understand correctly, the problem you’re facing is that after the first epoch, the encoder continues to be initialised from the
facebook/bart-base checkpoint - is that right?
as you suspect, i think this line might be the problem
self.encoder = AutoModel.from_pretrained(self.config.encoder_model)
config.encoder_model would always point to whatever value you defined in the config. i wonder whether the problem can be solved by replacing
AutoModel.from_pretrained with a dedicated model class like
self.encoder = BartModel(config)
this is closer to what you see in the source code for
BertForSequenceClassification and (i think) ensures the model is loaded from
config.json associated with each epoch’s checkpoint.
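A minimal sketch of that pattern, under some loud assumptions: it uses a tiny, randomly initialised BertModel as a stand-in for BART so it runs without downloading anything, and EncoderWithHead and its simplified _init_weights are hypothetical. The encoder is built from the config in __init__, so the only weights that end up in the model are the ones from_pretrained reads from the checkpoint directory.

```python
import tempfile

import torch
from torch import nn
from transformers import BertConfig, BertModel, PreTrainedModel


class EncoderWithHead(PreTrainedModel):  # hypothetical name
    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        # Architecture only: no from_pretrained call inside __init__.
        self.encoder = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def _init_weights(self, module):
        # Simplified initialiser for this sketch (linear layers only).
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return {"logits": self.classifier(hidden[:, 0])}


# Tiny config (assumed values) so the round trip is fast.
config = BertConfig(
    vocab_size=50, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
    intermediate_size=32, max_position_embeddings=32, num_labels=3,
)
model = EncoderWithHead(config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    model.save_pretrained(checkpoint_dir)  # writes config.json plus weights
    reloaded = EncoderWithHead.from_pretrained(checkpoint_dir)

saved, restored = model.state_dict(), reloaded.state_dict()
all_match = all(torch.equal(saved[name], restored[name]) for name in saved)
```

With this layout, the public pretrained encoder weights would be loaded once, outside the class, before training starts; every later reload of a checkpoint then goes through from_pretrained on the checkpoint directory alone.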