How do we reassemble subtokens when running a token classification model at inference time on a sentence?

I fine-tuned a BERT model for a token classification task, and I want to add a post-processing step that reassembles the generated subtokens (and aligns the corresponding tags) when testing on a new sentence.

For example :

import torch
import numpy as np
from transformers import BertTokenizer, BertConfig

model = torch.load('path/finetuned_bert.pth')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tag_values = ["O","ORG"]
tag2idx = {t: i for i, t in enumerate(tag_values)}

test_sentence = """McDonald's is a well-known fast food chain"""
tokenized = tokenizer.encode(test_sentence)
input_ids = torch.tensor([tokenized])

with torch.no_grad():
    output = model(input_ids)

label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)
tokens = tokenizer.convert_ids_to_tokens(input_ids.to('cpu').numpy()[0])

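# join split tokens
new_tokens, new_labels = [], []
for token, label_idx in zip(tokens[1:-1], label_indices[0][1:-1]):
    if token.startswith("##"):
        # continuation piece: glue it onto the previous token
        new_tokens[-1] = new_tokens[-1] + token[2:]
    else:
        # first piece of a word: keep its token and its label
        new_labels.append(tag_values[label_idx])
        new_tokens.append(token)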

for token, label in zip(new_tokens, new_labels):
    print("{}\t{}".format(label, token))

# output : 

# ORG	mcdonald
# O	'
# O	s
# O	is
# O	a
# O	well
# O	-
# O	known
# O	fast
# O	food
# O	chain

I want to get a result similar to this one:

# output : 

# ORG	McDonald's
# O	is
# O	a
# O	well-known
# O	fast
# O	food
# O	chain

Any help would be appreciated.

@clem @mariosasko @merve @sgugger do you have an idea?
Thanks in advance

You may want to use "pipeline" with an aggregation strategy:

from transformers import pipeline

nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

For more information on the aggregation strategy, you can read this
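If you need whitespace-delimited words such as "McDonald's" or "well-known" with their original casing, one option is to post-process the pipeline result using the `start` and `end` character offsets it returns. Below is a minimal sketch of that idea, not something the library provides: `merge_into_words` is a hypothetical helper, the `predictions` list is made-up sample data rather than real model output, and I am assuming the pipeline drops "O" predictions (its default `ignore_labels`), so any word without an overlapping prediction gets "O".

```python
import re


def merge_into_words(sentence, predictions, default_label="O"):
    """Label each whitespace-delimited word of `sentence`.

    `predictions` is a list of dicts shaped like the token classification
    pipeline output: each has 'start' and 'end' character offsets and a label
    under 'entity_group' (aggregated) or 'entity' (non-aggregated).
    A word takes the label of the first prediction that overlaps it.
    """
    results = []
    for match in re.finditer(r"\S+", sentence):
        word_start, word_end = match.span()
        label = default_label
        for pred in predictions:
            # Two character spans overlap if each starts before the other ends.
            if pred["start"] < word_end and pred["end"] > word_start:
                label = pred.get("entity_group", pred.get("entity", default_label))
                break
        # Slice from the original sentence so casing and punctuation survive.
        results.append((label, match.group()))
    return results


# Hypothetical sample data, standing in for `nlp(test_sentence)`.
test_sentence = "McDonald's is a well-known fast food chain"
predictions = [{"entity_group": "ORG", "word": "mcdonald", "start": 0, "end": 8}]

for label, word in merge_into_words(test_sentence, predictions):
    print("{}\t{}".format(label, word))
```

Taking the first overlapping label is a simplification; if the pieces of one word can carry different labels, you may prefer a majority vote or the highest-scoring piece instead.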

The pipeline will not return exactly the output you want, but a JSON-like list such as this one:

 [{'entity': 'LABEL_1',
  'score': 0.6112772,
  'index': 1,
  'word': 'token_1',
  'start': 0,
  'end': 5},
 {'entity': 'LABEL_1',
  'score': 0.532399,
  'index': 2,
  'word': 'token_2',
  'start': 6,
  'end': 7},
 {'entity': 'LABEL_1',
  'score': 0.5236228,
  'index': 3,
  'word': 'token_3',
  'start': 8,
  'end': 10},