How to use additional input features for NER?

Hello,

I’ve been following the documentation on fine-tuning with custom datasets (Fine-tuning with custom datasets — transformers 4.3.0 documentation), and I was wondering how additional token-level features (e.g. POS tags) can be used as input.

My intuition was to concatenate each token with its tag before feeding it into a pre-trained tokenizer (e.g. ["Arizona_NNP", "Ice_NNP", "Tea_NNP"]). Is this the right way to do it? Is there a better way?

Thank you in advance!


Actually no, because the pre-trained tokenizer only knows tokens, not tokens + POS tags. A better way to do this would be to create an additional input to the model (besides input_ids and token_type_ids) called pos_tag_ids, for which you can add an additional embedding layer (nn.Embedding). In that way, you can sum the embeddings of the tokens, token types and the POS tags. Let’s illustrate this for a pre-trained BERT model:

We first have to modify the BertEmbeddings class. In short, we’ll add an embedding layer for the POS tags:

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.pos_tag_embeddings = nn.Embedding(max_number_of_pos_tags, config.hidden_size)

        (...)
  
    def forward(
        self, input_ids=None, pos_tag_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        pos_tag_embeddings = self.pos_tag_embeddings(pos_tag_ids)

        embeddings = inputs_embeds + token_type_embeddings + pos_tag_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

The max_number_of_pos_tags is the total number of unique POS tags we have (it might be 20 for example, with NNP being one of them), also called the “vocabulary size” of the embedding layer. The config.hidden_size is the size of the embedding vector that we want to learn for each POS tag (which is 768 by default for BERT-base). We would also need to modify the forward pass of BertModel a bit to add the additional input pos_tag_ids:

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    pos_tag_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):

    (...)

    embedding_output = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        token_type_ids=token_type_ids,
        pos_tag_ids=pos_tag_ids,
        inputs_embeds=inputs_embeds,
        past_key_values_length=past_key_values_length,
    )

    (...)

Now that we have modified the model (modeling_bert.py), let’s move on to providing actual inputs to the model. An additional complexity of BERT-like models is that they rely on subword tokens rather than words. This means that a word like “Arizona” might be tokenized into [“Ari”, “##zona”], so we will also have to provide POS tags at the token level. Similar to how each token is turned into an integer (input_ids), we will also have to turn each POS tag into a corresponding integer (pos_tag_ids) in order to provide it to the model. So we actually need to keep a dictionary that maps each POS tag to a corresponding integer.

For simplicity, let’s assume that we only have two POS tags, namely NNP and VNP. We create corresponding integers (pos_tag_ids) for them, for example [0, 1]. So our vocabulary size of the POS tag embedding layer is only 2. Let’s now provide an example sentence to the model:

from transformers import BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "She sells"
# if we tokenize it, this becomes:
encoding = tokenizer(text, return_tensors="pt") # this creates a dictionary with keys 'input_ids' etc.
# we add the pos_tag_ids to the dictionary
pos_tags = ["NNP", "VNP"]  # which we map to the ids 0 and 1
encoding['pos_tag_ids'] = torch.tensor([[0, 1]])

# next, we can provide this to our modified BertModel:
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
outputs = model(**encoding)

Note that the code above assumes that each word is turned into a single token, which is not always the case. Suppose for example that the word Arizona is tokenized into [“Ari”, “##zona”]; then we would have pos_tag_ids [0, 0] for it.
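If you are using one of the fast tokenizers, the word_ids() method of the encoding makes this word-to-subword alignment straightforward. Here is a minimal sketch (the word list and tag-to-id mapping are made up for illustration, and the special tokens are given a placeholder id of 0):

import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

words = ["Arizona", "Ice", "Tea"]
word_pos_ids = [0, 0, 0]  # one POS tag id per word (here all NNP = 0)

encoding = tokenizer(words, is_split_into_words=True, return_tensors="pt")

# copy each word's POS tag id to all of its subword tokens;
# special tokens ([CLS], [SEP]) get word_id None, which we map to 0 as a placeholder
pos_tag_ids = [word_pos_ids[idx] if idx is not None else 0 for idx in encoding.word_ids()]
encoding["pos_tag_ids"] = torch.tensor([pos_tag_ids])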


Thank you so much for the detailed answer!

For addressing subtoken labeling, I can generally follow the same method detailed in the fine-tuning with custom datasets documentation, right? The only change is that instead of defaulting to -100 for subsequent subtokens, I label all subtokens with the POS ID of the original token?

Actually, that’s a design choice: you can label all subtokens of a word with the same label, or (and this is more commonly done) only label the first subtoken of a word and label the rest with -100, so that they are not taken into account by the loss function.
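As an illustration of the second option, here is a rough sketch. It assumes a fast tokenizer and word-level label ids; the words and label ids are made up:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

words = ["Arizona", "Ice", "Tea"]
word_labels = [1, 2, 2]  # one (made-up) label id per word

encoding = tokenizer(words, is_split_into_words=True)

labels = []
previous_word_id = None
for word_id in encoding.word_ids():
    if word_id is None:                  # special tokens
        labels.append(-100)
    elif word_id != previous_word_id:    # first subtoken of a word
        labels.append(word_labels[word_id])
    else:                                # remaining subtokens of the same word
        labels.append(-100)
    previous_word_id = word_id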


Important note: @nielsr’s approach is definitely reasonable, but I would argue that you should use a separate optimizer for the POS embeddings when you fine-tune. The reason is that the main model (+ embeddings) is already pretrained, whereas the POS embeddings are not. You’d likely need a larger lr for those new embeddings.

An alternative approach is adding layers on top of the model which concatenate POS features, e.g. one-hot encoded, and pass them to an RNN, for instance.
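Something along these lines, for example (just a rough sketch with made-up sizes such as num_pos_tags and num_labels, not a drop-in implementation):

import torch
import torch.nn as nn
from transformers import BertModel

class BertWithPosFeatures(nn.Module):
    # Sketch: concatenate one-hot POS features to BERT's token representations
    # and run them through a BiLSTM before a token classification layer.
    def __init__(self, num_pos_tags=20, num_labels=9, hidden=256):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.lstm = nn.LSTM(self.bert.config.hidden_size + num_pos_tags,
                            hidden, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * hidden, num_labels)
        self.num_pos_tags = num_pos_tags

    def forward(self, input_ids, attention_mask, pos_tag_ids):
        sequence_output = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        pos_one_hot = nn.functional.one_hot(pos_tag_ids, num_classes=self.num_pos_tags).float()
        combined = torch.cat([sequence_output, pos_one_hot], dim=-1)
        lstm_out, _ = self.lstm(combined)
        return self.classifier(lstm_out)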


@BramVanroy I think you can use a single optimizer, but specify a different learning rate for the parameters you want, like so:

optimizer = optim.Adam([
                {'params': model.parameters()},
                {'params': model.embeddings.pos_tag_embeddings(), 'lr': 5e-5}
            ], lr=2e-5)

Yes, I worded it incorrectly but I hope it is clear what I meant, as I mentioned lr.

I don’t think your snippet is completely correct though (the second dict is missing a .parameters() call), and I am also not sure whether you can include the same params in two separate dicts, or what the consequences are when you do this. (In your snippet the POS embeddings are included in both model.parameters() and the second params dict.) You might need to make that distinction exclusive, but I am not sure about that.
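Something like this might work, though I haven’t tested it (model here is assumed to be the modified model from above, with a pos_tag_embeddings attribute):

import torch

# keep the parameter groups exclusive by filtering the POS tag embedding
# parameters out of the main group
pos_params = list(model.embeddings.pos_tag_embeddings.parameters())
pos_param_ids = {id(p) for p in pos_params}
base_params = [p for p in model.parameters() if id(p) not in pos_param_ids]

optimizer = torch.optim.Adam([
    {'params': base_params, 'lr': 2e-5},
    {'params': pos_params, 'lr': 5e-5},
])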

encoding['pos_tag_ids'] = torch.tensor([[0, 1]]) should be encoding['pos_tag_ids'] = torch.tensor([0, 1]) for each text. My dimensions are not matching with your method.

Even after doing this, the Trainer API is detecting it as a single dimension, as opposed to labels. Is there anything additional that needs to be done when adding the key-value pair to the tokenizer’s dictionary?

Ok, I finally fixed it, and made a clean method to add new features to any Bert model. Thanks for your help guys.

Ok, I tried this method with the Trainer API, so I added the additional key pos_tag_ids to the tokenizer output, but it doesn’t seem to work. I am getting either a permute error or a batch size error. Are you sure about the tensor assignment you provided, encoding['pos_tag_ids'] = torch.tensor([[0, 1]])?

Any help would be appreciated

Actually, my tutorial was a bit simplistic (can’t seem to edit my tutorial above). Let’s take a more realistic example. Suppose that you have a list of words like [“My”, “name”, “is”, “Niels”], and the corresponding POS tags are [DET, NOUN, AUX, PROPN]. Here’s how to prepare the additional input features:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

words = ["My", "name", "is", "Niels"]
pos_tags = ["DET", "NOUN", "AUX", "PROPN"]


tag2id = {'NA': 0, 'DET': 1, 'NOUN':2, 'AUX':3, 'PROPN':4}
id2tag = {v:k for k,v in tag2id.items()}

tokens = []
pos_tag_tokens = []
for word, tag in zip(words, pos_tags):
  # tokenize the word
  word_tokens = tokenizer.tokenize(word)
  tokens.extend(word_tokens)
  # copy the POS tag for all word tokens
  pos_tag_tokens.extend([tag for _ in range(len(word_tokens))])

# Truncation: account for [CLS] and [SEP] with "- 2". 
special_tokens_count = 2 
max_seq_length = 512
if len(tokens) > max_seq_length - special_tokens_count:
    tokens = tokens[: (max_seq_length - special_tokens_count)]
    pos_tag_tokens = pos_tag_tokens[: (max_seq_length - special_tokens_count)]

# add special tokens + corresponding POS tags
tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
pos_tag_tokens = ['NA'] + pos_tag_tokens + ['NA']

# create input_ids + attention_mask
input_ids = tokenizer.convert_tokens_to_ids(tokens)
attention_mask = [1] * len(input_ids)
print(pos_tag_tokens)
pos_tag_ids = [tag2id[tag] for tag in pos_tag_tokens]

# padding up to max_seq_length
padding_length = max_seq_length - len(input_ids)
input_ids += [tokenizer.pad_token_id] * padding_length
attention_mask += [0] * padding_length
pos_tag_ids += [0] * padding_length

print(tokenizer.convert_ids_to_tokens(input_ids))
print(pos_tag_ids)

In reality, we also need to add POS tag IDs for special tokens ([CLS], [SEP] and [PAD]) - I’m setting the POS tag id for the special tokens to 0, which means “NA” (not applicable). Moreover, it is possible that a word is tokenized into several tokens, hence we must create these features for each of the tokens of a given word.

Now we can give this as input to the model (note that, just like BertModel, the forward of BertForTokenClassification must also be modified to accept pos_tag_ids and pass it on to self.bert):

from transformers import BertForTokenClassification
import torch

model = BertForTokenClassification.from_pretrained("bert-base-uncased")

input_ids = torch.tensor(input_ids).unsqueeze(0) # batch size of 1
attention_mask = torch.tensor(attention_mask).unsqueeze(0) # batch size of 1
pos_tag_ids = torch.tensor(pos_tag_ids).unsqueeze(0) # batch size of 1

outputs = model(input_ids=input_ids, attention_mask=attention_mask, pos_tag_ids=pos_tag_ids)

Thanks for the reply, I somehow managed to follow your method. Now I am not getting the desired accuracy, and I was setting up optimizers for the pos_tag_embeddings layer.

optimizer1 = torch.optim.Adam([ 
                                {'params': model.bert.embeddings.pos_tag_embeddings.parameters(), 'lr': 5e-5},
                                {'params': model.bert.parameters()},
                              ], lr=2e-5)

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
model.base_model.embeddings.parameters

This is throwing an error; any idea how to train this embedding layer properly?

I’ll add special tokens for the embedding layer, I didn’t do that yet. Thanks so much.

I was wondering whether it would be a nice solution if, instead of providing a new embedding layer to the model, we focused on providing the additional features to the classification layer, and let the model learn this custom classification layer during fine-tuning. Would that be a right approach?

Yes that’s also an option.


@f3n1Xx did you decide to add features to your classification layer? This is what I am working on.
@nielsr do you think this method would have better results?

I am doing both, I have to include the performance of both methods in my work.

Thanks @f3n1Xx. Would you mind sharing how you added new features to your model?

I didn’t check it completely, but my first thought was to concatenate the additional data to the pooled output of the transformer.

My second thought is to create two separate classification layers: the first one takes the pooled output, and the second takes the first layer’s input plus the additional features and gives the logits.
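Roughly along the lines of the first idea, something like this (just a sketch with made-up feature sizes, and for sequence-level rather than token-level classification):

import torch
import torch.nn as nn
from transformers import BertModel

class BertWithExtraFeatures(nn.Module):
    # Sketch: concatenate extra features to the pooled output before a
    # classification layer (num_extra_features and num_labels are placeholders).
    def __init__(self, num_extra_features=10, num_labels=2):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(self.bert.config.hidden_size + num_extra_features, num_labels)

    def forward(self, input_ids, attention_mask, extra_features):
        pooled = self.bert(input_ids, attention_mask=attention_mask).pooler_output
        return self.classifier(torch.cat([pooled, extra_features], dim=-1))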


If anyone is still stuck on this problem, I have implemented it for a biological problem, so I don’t know how well it would translate to NLP.
https://repository.tudelft.nl/islandora/object/uuid%3A80fef18f-2b3c-4993-b5d8-10ca6dbf1b71
