Hello, I am new to HuggingFace and wanted to create my own nn.Module class that uses RoBERTa as an encoder. I am also hoping to use it with HuggingFace's Trainer class. Looking at the source code for Trainer, it looks like my model's forward only needs to return an object with outputs["loss"]. Is there anything else I need to do? Are there any resources/guides/tutorials for creating your own model?
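For concreteness, the rough shape of what I have in mind is something like this (a hypothetical sketch, not code I have actually written):

from torch import nn
from transformers import AutoModel


class RobertaTagger(nn.Module):
    # hypothetical model: RoBERTa encoder + a linear classification head
    def __init__(self, num_labels: int = 2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("roberta-base")
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        logits = self.classifier(hidden[:, 0, :])  # representation of the <s> token
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        # Trainer reads the loss from outputs["loss"] (or outputs[0] for tuples)
        return {"loss": loss, "logits": logits}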
Hi @gabeorlanski, I’m not aware of any dedicated tutorials for building custom models, but my suggestion would be to subclass PreTrainedModel (check out how e.g. BertForSequenceClassification is implemented) or one of the existing model classes. This has several advantages over using nn.Module:
- You get all the helper methods like from_pretrained for free
- Your custom model will play nice with the Trainer

Depending on your use case, you can also override methods directly in the Trainer - see here for a list of the available methods.
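For example, here is a minimal, hypothetical sketch of overriding compute_loss to plug in your own loss function (the class weights are made up):

import torch
from torch import nn
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    # hypothetical subclass: swap in a class-weighted cross-entropy loss
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs["logits"]
        # made-up weights; adjust for your label distribution
        weight = torch.tensor([1.0, 2.0, 2.0], device=logits.device)
        loss = nn.CrossEntropyLoss(weight=weight)(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss

You would then instantiate WeightedLossTrainer exactly as you would a normal Trainer.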
Hi @lewtun, sorry for the late reply. Thank you for the suggestion! I have been trying that method out, subclassing PreTrainedModel and using a separate AutoModel as the encoder for the model. I have noticed that my model's from_pretrained does not seem to load the pretrained encoder. Is there any documentation/examples for this use case?
Hmm, that’s a bit odd. Would you be able to share a code snippet / Colab notebook with your workflow?
For the model/config I used:
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class TagPredictionConfig(PretrainedConfig):
    def __init__(self,
                 model_type: str = "bart",
                 encoder_model: str = "facebook/bart-base",
                 num_labels: int = 5,
                 dropout: float = .5,
                 inner_dim: int = 1024,
                 max_len: int = 128,
                 unique_label_count: int = 10,
                 **kwargs):
        super(TagPredictionConfig, self).__init__(num_labels=num_labels, **kwargs)
        self.model_type = model_type
        self.encoder_model = encoder_model
        self.dropout = dropout
        self.inner_dim = inner_dim
        self.max_length = max_len
        self.unique_label_count = unique_label_count
        self.intent_token = '<intent>'
        self.snippet_token = '<snippet>'
        self.columns_used = ['snippet_tokenized', 'canonical_intent', 'tags']

        encoder_config = AutoConfig.from_pretrained(
            self.encoder_model,
        )
        self.vocab_size = encoder_config.vocab_size
        self.eos_token_id = encoder_config.eos_token_id
class TagPredictionModel(PreTrainedModel):
    config_class = TagPredictionConfig

    def __init__(self, config: TagPredictionConfig):
        super(TagPredictionModel, self).__init__(config)
        self.config = config
        self.encoder = AutoModel.from_pretrained(self.config.encoder_model)
        self.encoder.resize_token_embeddings(self.config.vocab_size)
        self.dense_1 = nn.Linear(
            self.encoder.config.hidden_size,
            self.config.inner_dim,
            bias=False
        )
        self.dense_2 = nn.Linear(
            self.config.inner_dim,
            self.config.unique_label_count,
            bias=False
        )
        self.dropout = nn.Dropout(self.config.dropout)
        self.encoder._init_weights(self.dense_1)
        self.encoder._init_weights(self.dense_2)
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            return_dict=None,
            **kwargs):
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # labels=labels,
            return_dict=return_dict,
        )
        hidden_states = encoded[0]  # last hidden state
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        # use the hidden state of the final <eos> token as the sequence representation
        encoded_rep = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
        classification_hidden = self.dropout(encoded_rep)
        classification_hidden = torch.tanh(self.dense_1(classification_hidden))
        classification_hidden = self.dropout(classification_hidden)
        # raw logits; BCEWithLogitsLoss applies the sigmoid internally
        logits = self.dense_2(classification_hidden)
        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
        return TagPredictionOutput(
            loss=loss,
            logits=logits
        )
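TagPredictionOutput is just a small ModelOutput wrapper for the loss and logits, roughly along these lines (sketch; the actual definition may differ slightly):

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class TagPredictionOutput(ModelOutput):
    # minimal output wrapper holding what forward returns
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None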
Most of the code is taken straight from the BertForSequenceClassification model; I only modified it slightly because I want to use it with T5 and a multi-label classification task.
The trainer setup is below (I can share the training args, but I left them out to avoid clutter):
args_dict = TrainingArguments(**simpleTrainingArgs("./experiments/"))
trainer = Trainer(
    model=model,
    args=args_dict,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
    tokenizer=preprocessor.tokenizer
)
trainer.train()
I think the main issue comes from the default model name in the config. When I do not set the default value of encoder_model to a real model, it errors out while trying to load the pretrained encoder. This happened after the first epoch, because the Trainer was trying to load the best model. I can do more testing and see if it happens with more than one epoch.
@lewtun, to add on: after more testing, I now get the following warning:
Some weights of TagPredictionModel were not initialized from the model checkpoint at ./experiments/checkpoint-40 and are newly initialized: ['.encoder.shared.weight', '.encoder.encoder.embed_tokens.weight', '.encoder.encoder.embed_positions.weight', '.encoder.encoder.layers.0.self_attn.k_proj.weight'…(Cut for length)
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
EDIT: Also, after rerunning evaluation on the validation set once training ends, I am almost certain the weights are not being saved properly, because the eval loss is different from what it was during the training loop.
hey @gabeorlanski, sorry for the slow reply.
if i understand correctly, the problem you’re facing is that after the first epoch, the encoder continues to be initialised from the facebook/bart-base checkpoint - is that right?
as you suspect, i think this line might be the problem:

self.encoder = AutoModel.from_pretrained(self.config.encoder_model)

because config.encoder_model would always point to whatever value you defined in the config. i wonder whether the problem can be solved by replacing AutoModel.from_pretrained with a dedicated model class like

self.encoder = BartModel(config)

this is closer to what you see in the source code for BertForSequenceClassification and (i think) ensures the model is loaded from the config.json associated with each epoch’s checkpoint.
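to make that concrete, here’s an untested sketch of what i have in mind (the names mirror your snippet; it assumes you stash the encoder’s config inside your custom config so that BartModel can be rebuilt from config.json alone):

from transformers import AutoConfig, BartConfig, BartModel, PretrainedConfig, PreTrainedModel


class TagPredictionConfig(PretrainedConfig):
    model_type = "tag_prediction"

    def __init__(self, encoder_model: str = "facebook/bart-base",
                 encoder_config: dict = None, **kwargs):
        super().__init__(**kwargs)
        self.encoder_model = encoder_model
        # store the encoder's config so it is serialised into config.json and the
        # encoder can be rebuilt from a checkpoint without touching the hub model
        self.encoder_config = encoder_config or AutoConfig.from_pretrained(encoder_model).to_dict()


class TagPredictionModel(PreTrainedModel):
    config_class = TagPredictionConfig

    def __init__(self, config: TagPredictionConfig):
        super().__init__(config)
        # build the architecture only; the weights are filled in by
        # TagPredictionModel.from_pretrained(checkpoint_dir)
        self.encoder = BartModel(BartConfig.from_dict(config.encoder_config))
        # ... classification head as in your snippet ...

with this setup you’d copy the pretrained bart weights into the encoder once before the first training run (e.g. model.encoder = BartModel.from_pretrained(config.encoder_model)), and from_pretrained on your wrapper should take care of every subsequent checkpoint reload.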