Having Multiple [MASK] tokens in a sentence

I would like to have multiple [MASK] tokens in a sentence, but I get an error when I try to run it.
What do I need to change to fix it?

Instead of: text = "The capital of France, " + tokenizer.mask_token + ", contains the Eiffel Tower."
I need: text = "The capital of France, " + tokenizer.mask_token + ", contains the Eiffel " + tokenizer.mask_token + "."

from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True)
text = "The capital of France, " + tokenizer.mask_token + ", contains the Eiffel Tower."
inputs = tokenizer.encode_plus(text, return_tensors="pt")
# Position(s) of the [MASK] token in the input sequence
mask_index = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)
with torch.no_grad():
    output = model(**inputs)
logits = output.logits
softmax = F.softmax(logits, dim=-1)
mask_word = softmax[0, mask_index, :]
# Token ids of the 10 most likely words at the (first) masked position
top_10 = torch.topk(mask_word, 10, dim=1)[1][0]
for token in top_10:
    word = tokenizer.decode([token])
    new_sentence = text.replace(tokenizer.mask_token, word)
    print(new_sentence)

I’ve used the code from here.

I’ve already looked at Multiple Mask Tokens, but I want the output to be a sentence.

I hope you can help me

Kind regards
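In the question's loop, `text.replace(tokenizer.mask_token, word)` is called without a count, and `str.replace` then substitutes every occurrence, so with two masks both would receive the same word. A minimal sketch of filling the masks one at a time, assuming you already have one chosen word per mask in left-to-right order (for example, the top prediction at each masked position). `fill_masks` is a hypothetical helper, not part of `transformers`.

```python
def fill_masks(text, mask_token, words):
    """Replace each mask_token in text, left to right, with the next word."""
    # Hypothetical helper: expects exactly one word per mask, in order.
    if text.count(mask_token) != len(words):
        raise ValueError("need exactly one word per mask token")
    for word in words:
        # count=1 replaces only the leftmost remaining mask
        text = text.replace(mask_token, word, 1)
    return text


text = "The capital of France, [MASK], contains the Eiffel [MASK]."
print(fill_masks(text, "[MASK]", ["paris", "tower"]))
# prints: The capital of France, paris, contains the Eiffel tower.
```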

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

sentence = "The capital of France [MASK] contains the Eiffel [MASK]."

token_ids = tokenizer.encode(sentence, return_tensors='pt')

# Positions of every [MASK] token in the (single) input sequence
masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
masked_pos = [mask.item() for mask in masked_position]
print(masked_pos)

with torch.no_grad():
    output = model(token_ids)

# output[0] holds the prediction scores (logits): [seq_len, vocab_size] after squeeze
logits = output[0].squeeze()

print("\n\n")
print("sentence : ", sentence)
print("\n")

list_of_list = []
for mask_index in masked_pos:
    mask_logits = logits[mask_index]
    # Ids of the 100 highest-scoring vocabulary entries for this mask
    idx = torch.topk(mask_logits, k=100, dim=0)[1]
    words = [tokenizer.decode([i.item()]).strip() for i in idx]
    list_of_list.append(words)
    print(words)

# Join the top-ranked word for each mask
best_guess = ""
for j in list_of_list:
    best_guess = best_guess + " " + j[0]
print(best_guess)
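The answer keeps only the top word for each mask. If you want several complete sentences instead, one simple option (an approximation, not what the code above does) is to treat the masks as independent: score each combination of candidates by the product of its per-mask probabilities and rank the filled-in sentences. This ignores the dependency between the masks, since BERT predicts each one without seeing the word chosen for the other, so highly ranked pairs may not fit together. The sketch takes per-mask `(word, probability)` lists as input; in practice these could come from `torch.topk` over `F.softmax(logits, dim=-1)` at each masked position. `rank_sentences` and the probabilities below are invented for illustration.

```python
import itertools
import math


def rank_sentences(template, mask_token, candidates, n=3):
    """Return up to n filled-in sentences, best first.

    candidates[i] is a list of (word, probability) pairs for the i-th mask,
    in left-to-right order. A combination is scored by the sum of the
    log-probabilities, i.e. the product of the per-mask probabilities,
    which assumes the masks are independent.
    """
    scored = []
    for combo in itertools.product(*candidates):
        score = sum(math.log(p) for _, p in combo)
        text = template
        for word, _ in combo:
            # count=1 fills only the leftmost remaining mask
            text = text.replace(mask_token, word, 1)
        scored.append((score, text))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:n]]


# Invented probabilities, standing in for softmax outputs at the two masks
candidates = [
    [("paris", 0.9), ("lyon", 0.05)],
    [("tower", 0.8), ("park", 0.1)],
]
template = "The capital of France [MASK] contains the Eiffel [MASK]."
for sentence in rank_sentences(template, "[MASK]", candidates, n=2):
    print(sentence)
# prints:
# The capital of France paris contains the Eiffel tower.
# The capital of France paris contains the Eiffel park.
```

With k candidates per mask this enumerates k**m combinations for m masks, so keep k small; the answer's k=100 over two masks already yields 10,000 sentences.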

Thank you so much! This helped me a lot.
