For the model/config I used:
import torch
from torch import nn
from transformers import (AutoConfig, AutoModel, PretrainedConfig,
                          PreTrainedModel, Trainer, TrainingArguments)


class TagPredictionConfig(PretrainedConfig):
    def __init__(self,
                 model_type: str = "bart",
                 encoder_model: str = "facebook/bart-base",
                 num_labels: int = 5,
                 dropout: float = 0.5,
                 inner_dim: int = 1024,
                 max_len: int = 128,
                 unique_label_count: int = 10,
                 **kwargs):
        super(TagPredictionConfig, self).__init__(num_labels=num_labels, **kwargs)
        self.model_type = model_type
        self.encoder_model = encoder_model
        self.dropout = dropout
        self.inner_dim = inner_dim
        self.max_length = max_len
        self.unique_label_count = unique_label_count
        self.intent_token = '<intent>'
        self.snippet_token = '<snippet>'
        self.columns_used = ['snippet_tokenized', 'canonical_intent', 'tags']

        # Pull vocab size and eos id from the underlying encoder's config.
        encoder_config = AutoConfig.from_pretrained(self.encoder_model)
        self.vocab_size = encoder_config.vocab_size
        self.eos_token_id = encoder_config.eos_token_id
class TagPredictionModel(PreTrainedModel):
    config_class = TagPredictionConfig

    def __init__(self, config: TagPredictionConfig):
        super(TagPredictionModel, self).__init__(config)
        self.config = config
        self.encoder = AutoModel.from_pretrained(self.config.encoder_model)
        self.encoder.resize_token_embeddings(self.config.vocab_size)
        self.dense_1 = nn.Linear(
            self.encoder.config.hidden_size,
            self.config.inner_dim,
            bias=False
        )
        self.dense_2 = nn.Linear(
            self.config.inner_dim,
            self.config.unique_label_count,
            bias=False
        )
        self.dropout = nn.Dropout(self.config.dropout)
        self.encoder._init_weights(self.dense_1)
        self.encoder._init_weights(self.dense_2)
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            return_dict=None,
            **kwargs):
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # labels=labels,
            return_dict=return_dict,
        )
        hidden_states = encoded[0]  # last hidden state

        # Pool the hidden state of the final <eos> token, as in
        # BartForSequenceClassification.
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        encoded_rep = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]

        classification_hidden = self.dropout(encoded_rep)
        classification_hidden = torch.tanh(self.dense_1(classification_hidden))
        classification_hidden = self.dropout(classification_hidden)
        # Raw logits: BCEWithLogitsLoss applies the sigmoid itself, so calling
        # torch.sigmoid here as well would squash the inputs to the loss twice.
        logits = self.dense_2(classification_hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
        return TagPredictionOutput(
            loss=loss,
            logits=logits
        )
Most of the code is taken straight from the BartForSequenceClassification model; I only had to modify it slightly so that I can use it with T5 and for a multi-label classification task.
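TagPredictionOutput isn't shown above; a minimal sketch of it, assuming it is just a ModelOutput dataclass with the two fields used in forward, would be:

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class TagPredictionOutput(ModelOutput):
    # Optional loss (only set when labels are passed) and the raw logits.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None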
The trainer setup is below (I can also give the training args, but I thought they would clutter things too much, so I left them out):
args_dict = TrainingArguments(**simpleTrainingArgs("./experiments/"))
trainer = Trainer(
    model=model,
    args=args_dict,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
    tokenizer=preprocessor.tokenizer
)
trainer.train()
I think the main issue comes from the default model name that the config loads. When I do not set the default value of encoder_model to a real model, it errors out while trying to load the pretrained encoder. This happened after the first epoch, because that is when the Trainer tries to load the best model. I can do more testing to see whether it also happens when training for more than one epoch.
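A minimal way to check this outside the Trainer (just a sketch using the classes above, with a made-up checkpoint path):

# PreTrainedModel.from_pretrained re-runs __init__ with the saved config, and
# __init__ calls AutoModel.from_pretrained(config.encoder_model) again, so if
# encoder_model in the saved config does not point to a real model, this
# reload step is where it fails.
config = TagPredictionConfig()
model = TagPredictionModel(config)
model.save_pretrained("./experiments/checkpoint-test")
reloaded = TagPredictionModel.from_pretrained("./experiments/checkpoint-test")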