Hi
I am working on extracting legal entities (dates) from a corpus of agreements. For the training set, I tokenized the agreement text, tagged the date labels using the IOB convention, and fed them to the distilbert-base-uncased model. I defined the training arguments, the data collator and a compute_metrics method. After training completed, I passed in the preprocessed prediction dataset, which contains only agreement text, and I need to predict the date labels in it.
I am getting a tensor as output and don't know how to extract the date labels from it.
Can anyone help with this?
In the case of NER, one typically uses an xxxForTokenClassification model (which adds a linear layer on top of the base Transformer model). The logits of such models are typically of shape (batch_size, seq_len, num_labels). Let's take an existing, fine-tuned BertForTokenClassification model from the hub and perform inference on a new, unseen text:
from transformers import AutoTokenizer, BertForTokenClassification
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
Let’s prepare a new text for the model:
text = "Obama was the president of the United States and he was born in Hawai."
encoding = tokenizer(text, return_tensors="pt")
We now forward it through the model:
# forward pass
outputs = model(**encoding)
We now take the logits from the outputs, which are the scores that the model gives for each of the classes. In the case of token classification, the logits are of shape (batch_size, seq_len, num_labels). Let's check the shape:
logits = outputs.logits
print(logits.shape)
This prints torch.Size([1, 18, 9]). The batch size is 1 because we only have a single sentence, the sequence length is 18 tokens (including the special [CLS] and [SEP] tokens), and the number of labels is 9. So this model classifies each token into one of 9 possible labels. We can get the predictions by performing an argmax over the last dimension (i.e., the labels dimension), as follows:
predicted_label_classes = logits.argmax(-1)
print(predicted_label_classes)
This prints:
tensor([[0, 3, 0, 0, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0, 7, 7, 0, 0]])
Let’s now convert each predicted class index to the corresponding label name using the id2label mapping of the model’s configuration, as follows:
predicted_labels = [model.config.id2label[id] for id in predicted_label_classes.squeeze().tolist()]
print(predicted_labels)
This prints:
['O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'O', 'O']
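In the original question the predictions come from Trainer.predict rather than from a single forward pass. Its predictions field holds the logits for the whole dataset as a NumPy array of shape (num_examples, seq_len, num_labels), so the same argmax and id2label steps apply example by example. Below is a minimal sketch, assuming a hypothetical O / B-DATE / I-DATE label set and that you also pass the attention masks of the tokenized prediction dataset (for example dataset["attention_mask"]); the function name decode_predictions is made up for illustration.

```python
import numpy as np

# Hypothetical IOB label set for a date-extraction model; in practice use
# model.config.id2label (or the label list you trained with).
ID2LABEL = {0: "O", 1: "B-DATE", 2: "I-DATE"}


def decode_predictions(logits, attention_masks, id2label=ID2LABEL):
    """Convert (num_examples, seq_len, num_labels) logits into per-token label strings.

    attention_masks holds one mask per example, padded or not; positions whose
    mask is 0, or that lie beyond an unpadded mask, are dropped. Special tokens
    such as [CLS] and [SEP] are kept.
    """
    predicted_ids = np.argmax(logits, axis=-1)  # shape: (num_examples, seq_len)
    all_labels = []
    for ids, mask in zip(predicted_ids, attention_masks):
        # zip stops at the shorter sequence, so an unpadded mask also works
        all_labels.append([id2label[int(i)] for i, m in zip(ids, mask) if m == 1])
    return all_labels


# Toy logits for one example with 4 positions, the last one being padding.
logits = np.array([[[5.0, 0.1, 0.2], [0.1, 4.0, 0.3], [0.2, 0.1, 3.0], [0.0, 0.0, 0.0]]])
print(decode_predictions(logits, np.array([[1, 1, 1, 0]])))  # [['O', 'B-DATE', 'I-DATE']]
```

The result is still at the subword-token level, so with a WordPiece tokenizer such as DistilBERT's the tokens still have to be merged back into words.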
To see the correspondence between the tokens and the predicted labels, let’s print them side-by-side:
for token_id, label in zip(encoding.input_ids.squeeze().tolist(), predicted_labels):
    print(tokenizer.decode([token_id]), label)
This prints:
[CLS] O
Obama B-PER
was O
the O
president O
of O
the O
United B-LOC
States I-LOC
and O
he O
was O
born O
in O
Ha B-LOC
##wai B-LOC
. O
[SEP] O
So now we can clearly see the predictions. However, as the model uses subword tokenization, we would like to convert those to word-level predictions.
Several aggregation strategies are possible here. One strategy is to just select the prediction for the first token of each word; another is to average the predictions over all tokens of a word; yet another is to take the prediction of the token with the highest score among all tokens of a word, and so on. Which one to use depends on how the model was fine-tuned (for example, whether only the first subword token of each word received a label during training).
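The three strategies just described can be sketched in plain Python. This is an illustration of the idea, not the pipeline's actual implementation: it averages whatever scores you pass in (raw logits in the example), which need not match how the pipeline computes its own "average". It assumes token_scores is the per-token score list for one sequence, e.g. logits.squeeze().tolist(), and that word_ids comes from encoding.word_ids(), which fast tokenizers provide (None for special tokens). The function name word_level_labels is hypothetical.

```python
from collections import defaultdict


def word_level_labels(token_scores, word_ids, id2label, strategy="first"):
    """Aggregate per-token scores (seq_len x num_labels) into one label per word."""
    # Group the score vectors of all subword tokens by the word they belong to.
    grouped = defaultdict(list)
    for scores, word_id in zip(token_scores, word_ids):
        if word_id is not None:  # skip [CLS], [SEP] and padding
            grouped[word_id].append(scores)

    def argmax(values):
        return max(range(len(values)), key=lambda k: values[k])

    labels = []
    for word_id in sorted(grouped):
        pieces = grouped[word_id]
        if strategy == "first":
            best = argmax(pieces[0])  # prediction of the first subword
        elif strategy == "average":
            averaged = [sum(column) / len(pieces) for column in zip(*pieces)]
            best = argmax(averaged)  # argmax of the mean score vector
        elif strategy == "max":
            strongest = max(pieces, key=max)  # subword with the single highest score
            best = argmax(strongest)
        else:
            raise ValueError(f"unknown strategy: {strategy}")
        labels.append(id2label[best])
    return labels


# Toy input: "[CLS] born Ha ##wai [SEP]", where "Ha" and "##wai" form one word.
id2label = {0: "O", 1: "B-LOC", 2: "I-LOC"}
word_ids = [None, 0, 1, 1, None]
scores = [[9, 0, 0], [3, 0, 0], [3, 4, 0], [3, 0, 5], [9, 0, 0]]
for strategy in ("first", "average", "max"):
    print(strategy, word_level_labels(scores, word_ids, id2label, strategy))
# first ['O', 'B-LOC']
# average ['O', 'O']
# max ['O', 'I-LOC']
```

The toy scores are chosen so that the three strategies disagree on the same word, which is why the choice should match how the labels were aligned to subwords during fine-tuning.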
HuggingFace's pipeline API supports several aggregation strategies and abstracts away all of the steps above, plus grouping the tokens into entities. You can call it as follows:
from transformers import pipeline
nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
nlp(text)
This prints:
[{'end': 5,
'entity_group': 'PER',
'score': 0.998693,
'start': 0,
'word': 'Obama'},
{'end': 44,
'entity_group': 'LOC',
'score': 0.9994223,
'start': 31,
'word': 'United States'},
{'end': 69,
'entity_group': 'LOC',
'score': 0.9988722,
'start': 64,
'word': 'Hawai'}]
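For the date use case, the same call should work with your own model: assuming the fine-tuned model was saved with trainer.save_model(...) to a directory, that directory path can be passed as model= to pipeline. One caveat with an uncased model such as distilbert-base-uncased is that the "word" field is decoded from the tokens and may come back lowercased, so slicing the original text with the start/end offsets is a safer way to get the entity strings. A small sketch, where the helper name extract_entities and the default "DATE" group are hypothetical:

```python
def extract_entities(text, pipeline_output, entity_group="DATE"):
    """Return the original substrings of all grouped entities of one type."""
    return [
        text[entity["start"]:entity["end"]]  # slice the input to keep its casing
        for entity in pipeline_output
        if entity["entity_group"] == entity_group
    ]


# Using the output shown above for the example sentence.
text = "Obama was the president of the United States and he was born in Hawai."
results = [
    {"entity_group": "PER", "score": 0.998693, "word": "Obama", "start": 0, "end": 5},
    {"entity_group": "LOC", "score": 0.9994223, "word": "United States", "start": 31, "end": 44},
    {"entity_group": "LOC", "score": 0.9988722, "word": "Hawai", "start": 64, "end": 69},
]
print(extract_entities(text, results, "LOC"))  # ['United States', 'Hawai']
```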
Under the hood, it uses the offset mapping (offset_mapping, which is only supported by fast tokenizers) to determine which word each token belongs to. You can check the pipeline's source code for the details.
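To do the grouping manually, the offsets and word ids can be combined roughly as follows. This is a minimal sketch of a "first"-strategy grouping under stated assumptions, not the pipeline's code: it assumes the input was tokenized with return_offsets_mapping=True (remove the key with encoding.pop("offset_mapping") before calling model(**encoding), since the model does not accept it), that offsets is encoding["offset_mapping"].squeeze().tolist(), word_ids is encoding.word_ids(), and labels are the per-token IOB predictions. The function name group_entities is hypothetical, and the offsets in the example are written out by hand to mirror the sentence above.

```python
def group_entities(text, offsets, word_ids, token_labels):
    """Merge per-token IOB labels into (entity_type, substring) pairs.

    Each word takes the label of its first subword token; its character span
    covers all of its subwords. Consecutive B-X / I-X words are then merged.
    """
    # 1) Word-level spans: [label, start, end] per word.
    words = []
    prev_word = None
    for (start, end), word_id, label in zip(offsets, word_ids, token_labels):
        if word_id is None:
            continue  # special tokens ([CLS], [SEP], padding)
        if word_id != prev_word:
            words.append([label, start, end])
        else:
            words[-1][2] = end  # extend the word to cover this subword
        prev_word = word_id

    # 2) IOB grouping at the word level.
    entities = []
    current = None  # [entity_type, start, end] of the entity being built
    for label, start, end in words:
        prefix, _, entity_type = label.partition("-")
        if prefix == "I" and current is not None and current[0] == entity_type:
            current[2] = end  # continuation of the open entity
        elif prefix in ("B", "I"):
            if current is not None:
                entities.append(current)
            current = [entity_type, start, end]  # B- (or stray I-) opens a new entity
        else:  # "O" closes any open entity
            if current is not None:
                entities.append(current)
            current = None
    if current is not None:
        entities.append(current)
    return [(entity_type, text[start:end]) for entity_type, start, end in entities]


text = "Obama was the president of the United States and he was born in Hawai."
offsets = [(0, 0), (0, 5), (6, 9), (10, 13), (14, 23), (24, 26), (27, 30), (31, 37), (38, 44),
           (45, 48), (49, 51), (52, 55), (56, 60), (61, 63), (64, 66), (66, 69), (69, 70), (0, 0)]
word_ids = [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 14, None]
labels = ['O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC',
          'O', 'O', 'O', 'O', 'O', 'B-LOC', 'B-LOC', 'O', 'O']
print(group_entities(text, offsets, word_ids, labels))
# [('PER', 'Obama'), ('LOC', 'United States'), ('LOC', 'Hawai')]
```

Slicing the original text with the character offsets also preserves the original casing, which matters for an uncased model.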