How to write a custom configuration for a Hugging Face model for token classification

Model description

I added a simple custom pytorch-crf layer on top of a TokenClassification model for NER, with the aim of making the model more robust.

When I train the model, I get the following error:

***** Running training *****
  Num examples = 4
  Num Epochs = 2
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 4
TypeError: __init__() missing 3 required positional arguments: 'id2label', 'label2id', and 'num_labels'

Code

from torchcrf import CRF

# Name (or local path) of the pretrained checkpoint shared by the tokenizer
# and the base model.
model_checkpoint = "spanbert"

tokenizer = BertTokenizer.from_pretrained(
    model_checkpoint,
    add_prefix_space=True,
)

# id2label / label2id must already be defined before this point; they are
# not created anywhere in this snippet.
bert_model = BertForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
)

# The custom model's forward() averages the last four hidden-state layers,
# so the base model has to return all hidden states.
bert_model.config.output_hidden_states = True




class BertClassifierConfig(PretrainedConfig):
    """Configuration for the BERT + CRF token-classification model.

    Every argument is optional. The config class must be constructible
    without arguments; requiring ``id2label``, ``label2id`` and
    ``num_labels`` positionally is what produced
    ``TypeError: __init__() missing 3 required positional arguments``
    during training.

    Args:
        id2label: Optional mapping from label id to label name.
        label2id: Optional mapping from label name to label id.
        num_labels: Optional number of labels.
        **kwargs: Any other ``PretrainedConfig`` options, forwarded as-is.
    """

    model_type = "BertForTokenClassification"

    def __init__(self, id2label=None, label2id=None, num_labels=None, **kwargs):
        # Forward the label metadata to the base class rather than assigning
        # it on ``self`` first: the base initialiser sets these attributes
        # from its own keyword arguments, so values assigned before calling
        # it would be replaced by the defaults.
        if id2label is not None:
            kwargs["id2label"] = id2label
        if label2id is not None:
            kwargs["label2id"] = label2id
        if num_labels is not None:
            kwargs["num_labels"] = num_labels

        # Hidden states are requested by default; an explicit caller value
        # still takes precedence.
        kwargs.setdefault("output_hidden_states", True)

        super().__init__(**kwargs)

Model

class BertForTokenClassification(PreTrainedModel):
    """BERT encoder followed by dropout, a linear emission layer and a CRF.

    NOTE(review): this class has the same name as the transformers class
    used to load ``bert_model``; once this definition runs, the name refers
    to this wrapper.
    """

    config_class = BertClassifierConfig

    def __init__(self, config, bert_model, num_labels):
        """Wrap an already-loaded encoder with a CRF tagging head.

        Args:
            config: Configuration object passed to ``PreTrainedModel``.
            bert_model: Loaded encoder; it must return its hidden states
                (``output_hidden_states=True``), because ``forward`` reads
                them from the encoder output.
            num_labels: Number of tags emitted by the classifier and CRF.
        """
        super().__init__(config)
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)

        # Take the emission layer's input width from the encoder when it
        # exposes one; fall back to the previously hard-coded 768.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_labels)

        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        """Compute CRF predictions and, when labels are given, the loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask; also used (cast to uint8) as the
                CRF mask.
            labels: Optional gold tags; reshaped to the first two dimensions
                of ``attention_mask`` before being given to the CRF.
            token_type_ids: Optional segment ids, forwarded to the encoder
                when provided.

        Returns:
            ``[loss, prediction]`` when ``labels`` is given, otherwise only
            ``prediction`` (the result of ``CRF.decode``).
        """
        # Forward token_type_ids only when supplied, so encoders that do not
        # take this argument keep working.
        encoder_kwargs = {}
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.bert(input_ids, attention_mask=attention_mask, **encoder_kwargs)

        # outputs[1] is presumably the tuple of per-layer hidden states
        # (requires output_hidden_states=True on the encoder) -- TODO confirm.
        hidden_states = outputs[1]

        # Average the last four layers.
        last_four = [hidden_states[-k] for k in (1, 2, 3, 4)]
        sequence_output = torch.stack(last_four).mean(dim=0)
        sequence_output = self.dropout(sequence_output)

        # Original author's shape note: [32, 256, 17].
        emission = self.classifier(sequence_output)

        crf_mask = attention_mask.type(torch.uint8)
        prediction = self.crf.decode(emission, mask=crf_mask)

        # Inference: labels must not be touched when they are absent (the
        # reshape previously ran before this check and failed on None).
        if labels is None:
            return prediction

        labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])

        # NOTE(review): log_soft is not defined in this snippet; presumably
        # a log-softmax helper defined elsewhere -- TODO confirm.
        loss = -self.crf(log_soft(emission, 2), labels, mask=crf_mask, reduction='mean')
        return [loss, prediction]

Training setup

# Build the custom config and wrap the already-loaded base model.
label_count = len(label2id)
configuration = BertClassifierConfig(id2label, label2id, num_labels=label_count)
model = BertForTokenClassification(
    configuration,
    bert_model,
    num_labels=label_count,
)
model.to(device)

# Training hyper-parameters; a checkpoint is saved at the end of each epoch.
args = TrainingArguments(
    output_dir="test0000",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    # per_device_eval_batch_size=32,
    fp16=True,
    # bf16=True,  # Ampere GPU
)

# Evaluation, the data collator and metrics are left disabled here.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

Training and saving

# Run training, then write the trained model to the "modeltest" directory.
trainer.train()
trainer.save_model("modeltest")


# Register the custom config and model classes with the Auto* classes.
# The key passed to AutoConfig.register is the same string as
# BertClassifierConfig.model_type.
AutoConfig.register("BertForTokenClassification", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertForTokenClassification)

ERROR

 ***** Running training *****
      Num examples = 4
      Num Epochs = 2
      Instantaneous batch size per device = 2
      Total train batch size (w. parallel, distributed & accumulation) = 2
      Gradient Accumulation steps = 1
      Total optimization steps = 4
    TypeError: __init__() missing 3 required positional arguments: 'id2label', 'label2id', and 'num_labels'
1 Like

You don’t need to pass id2label, label2id and num_labels to your custom config yourself: these are special keywords that PretrainedConfig already handles, so you can just do:

class BertClassifierConfig(PretrainedConfig):
    """Custom config with no ``__init__`` of its own; only ``model_type`` is set."""

    model_type = "BertForTokenClassification"

# Pretrained checkpoint to start from; any BERT checkpoint name or path works.
bert_model = "bert-base-cased"  # etc.

# id2label must receive the id -> name mapping (it was previously given
# label2id by mistake); label2id receives the name -> id mapping.
configuration = BertClassifierConfig.from_pretrained(
    bert_model,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label2id),
)
model = BertForTokenClassification.from_pretrained(bert_model, config=configuration)

But if you need to pass other kwargs through your config, I don’t know how to do that either.

1 Like