How to add a custom CRF head on top of BERT for token classification?

I am currently working on a token classification task using PyTorch and a pre-trained BERT model. I want to add a custom CRF head on top of BERT. I have tried several code snippets like the one below, but I keep running into the error shown further down.

import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

bert = BertModel.from_pretrained('pretrained-bert', num_labels=2)
config = bert.config

class BERT_CRF(nn.Module):
    def __init__(self):
        super(BERT_CRF, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = outputs[0]               # (batch, seq_len, hidden_size)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # emission scores
        # pytorch-crf expects a ByteTensor/BoolTensor mask
        loss = -self.crf(emissions=logits, tags=labels, mask=attention_mask.bool())
        return loss

Error


IndexError Traceback (most recent call last)
Cell In[56], line 5
3 for batch in train_dataloader:
4 model.zero_grad()
----> 5 loss = model(**batch.to(device))
6 print(loss)
7 break

File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don’t have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[55], line 15, in BERT_CRF.forward(self, input_ids, token_type_ids, attention_mask, labels)
13 emission = self.classifier(sequence_output)
14 attn_masks = attention_mask.type(torch.uint8)
---> 15 loss = -self.crf(log_soft(emission, 2), labels, mask=attn_masks, reduction='mean')
16 return loss

File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don’t have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:102, in CRF.forward(self, emissions, tags, mask, reduction)
99 mask = mask.transpose(0, 1)
101 # shape: (batch_size,)
-> 102 numerator = self._compute_score(emissions, tags, mask)
103 # shape: (batch_size,)
104 denominator = self._compute_normalizer(emissions, mask)

File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:186, in CRF._compute_score(self, emissions, tags, mask)
182 mask = mask.float()
184 # Start transition score and first emission
185 # shape: (batch_size,)
-> 186 score = self.start_transitions[tags[0]]
187 score += emissions[0, torch.arange(batch_size), tags[0]]
189 for i in range(1, seq_length):
190 # Transition score to next tag, only added if next timestep is valid (mask == 1)
191 # shape: (batch_size,)

IndexError: index -100 is out of bounds for dimension 0 with size 2

I believe you used -100 as a label to be ignored by the loss (as is the case with CrossEntropyLoss). However, the default implementation of the pytorch-crf package does not allow ignoring labels, except for padding at the end of the sequence.
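
If you want to keep using pytorch-crf, a common workaround is to remove the -100 entries before they reach the CRF: give those positions some valid placeholder tag id and exclude them through the mask argument instead. A minimal sketch (the helper name and the pad_tag_id default are my own, not part of pytorch-crf):

import torch

def build_crf_inputs(labels, attention_mask, pad_tag_id=0):
    # Hypothetical helper: positions labelled -100 (which CrossEntropyLoss would
    # ignore) are dropped from the CRF mask and given a placeholder tag id,
    # because pytorch-crf indexes its transition tables with every entry of `tags`.
    crf_mask = attention_mask.bool() & labels.ne(-100)
    crf_labels = labels.clone()
    crf_labels[labels.eq(-100)] = pad_tag_id
    return crf_labels, crf_mask

Two caveats: pytorch-crf requires the mask to be on at the first timestep, so if your [CLS] position is labelled -100 you would still have to slice it off (e.g. logits[:, 1:], labels[:, 1:]) or give it a real tag; and masking tokens in the middle of a sequence is not really what the package was designed for, which is one reason a partial CRF, like the one linked below, is attractive.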

One solution I have found is this implementation, https://github.com/modelscope/AdaSeq/blob/master/adaseq/modules/decoders/partial_crf.py, which allows you to mask a specific label.
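
I have not used that AdaSeq code myself, so rather than guessing at its API, here is a self-contained sketch of the idea behind such a partial CRF: positions whose gold label equals the ignore index are treated as unobserved, and the numerator of the log-likelihood marginalises over every possible tag at those positions instead of indexing the transition tables with -100. Class and parameter names below are my own, not AdaSeq's:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPartialCRF(nn.Module):
    """Illustrative linear-chain CRF whose log-likelihood marginalises over
    positions labelled `ignore_index` (a sketch of the partial-CRF idea,
    not the AdaSeq implementation)."""

    def __init__(self, num_tags, ignore_index=-100):
        super().__init__()
        self.num_tags = num_tags
        self.ignore_index = ignore_index
        self.start = nn.Parameter(torch.zeros(num_tags))
        self.end = nn.Parameter(torch.zeros(num_tags))
        self.trans = nn.Parameter(torch.zeros(num_tags, num_tags))  # trans[i, j]: score of tag i -> tag j

    def _forward_score(self, emissions, mask, allowed):
        # emissions: (B, T, K); mask: (B, T) bool; allowed: (B, T, K) bool.
        # Standard forward algorithm, except tags outside `allowed` get a -inf
        # emission, so only paths consistent with the constraints contribute.
        emissions = emissions.masked_fill(~allowed, torch.finfo(emissions.dtype).min)
        score = self.start + emissions[:, 0]                                  # (B, K)
        for t in range(1, emissions.size(1)):
            broadcast = score.unsqueeze(2) + self.trans + emissions[:, t].unsqueeze(1)
            next_score = torch.logsumexp(broadcast, dim=1)                    # (B, K)
            # Padded timesteps keep the previous score, as in pytorch-crf.
            score = torch.where(mask[:, t].unsqueeze(1), next_score, score)
        return torch.logsumexp(score + self.end, dim=1)                       # (B,)

    def forward(self, emissions, tags, mask):
        # Mean negative log-likelihood; like pytorch-crf, assumes mask[:, 0] is all True.
        all_tags = torch.ones(emissions.shape, dtype=torch.bool, device=emissions.device)
        gold = F.one_hot(tags.clamp(min=0), self.num_tags).bool()
        observed = tags.ne(self.ignore_index).unsqueeze(-1)
        # Observed positions are pinned to their gold tag; ignored positions allow any tag.
        allowed = torch.where(observed, gold, all_tags)
        numerator = self._forward_score(emissions, mask, allowed)
        denominator = self._forward_score(emissions, mask, all_tags)
        return (denominator - numerator).mean()

With something like this you could keep your labels exactly as they are (-100s included) and swap self.crf = CRF(config.num_labels, batch_first=True) for this module; note that it returns the negative log-likelihood directly, so you would call loss = self.crf(logits, labels, attention_mask.bool()) without the leading minus sign.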