I am working on a token classification task using PyTorch and a pre-trained BERT model, and I want to add a custom CRF head (torchcrf) on top of it. I have tried several variants of the snippet below, but the forward pass always fails with the error shown further down.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF
from transformers import BertModel

bert = BertModel.from_pretrained('pretrained-bert', num_labels=2)
config = bert.config

class BERT_CRF(nn.Module):
    def __init__(self):
        super(BERT_CRF, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = self.dropout(outputs[0])     # (batch, seq_len, hidden_size)
        emission = self.classifier(sequence_output)    # (batch, seq_len, num_labels)
        attn_masks = attention_mask.type(torch.uint8)  # torchcrf expects a uint8/bool mask
        loss = -self.crf(F.log_softmax(emission, 2), labels,
                         mask=attn_masks, reduction='mean')
        return loss
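
For context, here is the call that triggers the error, reconstructed from the traceback below (the instantiation model = BERT_CRF().to(device) and the DataLoader setup are assumed to happen earlier):

for batch in train_dataloader:
    model.zero_grad()
    loss = model(**batch.to(device))
    print(loss)
    break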
Error
IndexError Traceback (most recent call last)
Cell In[56], line 5
3 for batch in train_dataloader:
4 model.zero_grad()
----> 5 loss = model(**batch.to(device))
6 print(loss)
7 break
File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[55], line 15, in BERT_CRF.forward(self, input_ids, token_type_ids, attention_mask, labels)
13 emission = self.classifier(sequence_output)
14 attn_masks = attention_mask.type(torch.uint8)
---> 15 loss = -self.crf(log_soft(emission, 2), labels, mask=attn_masks, reduction='mean')
16 return loss
File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:102, in CRF.forward(self, emissions, tags, mask, reduction)
99 mask = mask.transpose(0, 1)
101 # shape: (batch_size,)
--> 102 numerator = self._compute_score(emissions, tags, mask)
103 # shape: (batch_size,)
104 denominator = self._compute_normalizer(emissions, mask)
File C:\ProgramData\Anaconda3\envs\llm\lib\site-packages\torchcrf\__init__.py:186, in CRF._compute_score(self, emissions, tags, mask)
182 mask = mask.float()
184 # Start transition score and first emission
185 # shape: (batch_size,)
--> 186 score = self.start_transitions[tags[0]]
187 score += emissions[0, torch.arange(batch_size), tags[0]]
189 for i in range(1, seq_length):
190 # Transition score to next tag, only added if next timestep is valid (mask == 1)
191 # shape: (batch_size,)
IndexError: index -100 is out of bounds for dimension 0 with size 2
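
The last frame shows the CRF indexing self.start_transitions with a tag value of -100, so my labels apparently still contain the -100 padding that HuggingFace tokenizers emit for special tokens and subword pieces (the ignore_index convention of CrossEntropyLoss, which torchcrf does not understand). Below is a minimal sketch of the workaround I am considering inside forward (names as in the class above; it assumes -100 only marks positions that should not contribute to the CRF score):

# sketch: remap the -100 ignore labels to a valid tag id and mask those positions out
crf_labels = labels.clamp(min=0)                           # -100 -> 0 (any valid tag id works here)
crf_mask = attention_mask.byte() & labels.ne(-100).byte()  # drop ignored positions from the score
crf_mask[:, 0] = 1                                         # torchcrf requires timestep 0 to be unmasked
                                                           # (the [CLS] position then scores as tag 0; worth verifying)
loss = -self.crf(F.log_softmax(emission, 2), crf_labels, mask=crf_mask, reduction='mean')

Is this the right way to handle -100 labels with torchcrf, or should I instead assign the special tokens a real 'O' tag when tokenizing?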