I was trying to solve this problem:
We’re going to use the wikitext dataset with the distilbert-base-cased model checkpoint. Start by loading the wikitext-2-raw-v1 version of that dataset, and take the 11th example (index 10) of the train split. We’ll tokenize this using the appropriate tokenizer, and we’ll mask the sixth token (index 5) of the sequence. When using the distilbert-base-cased checkpoint to unmask that token (sixth token, index 5), what is the most probable predicted token (please provide the decoded token, and not the ID)?
and I coded the solution as follows:
import torch
import transformers
import nlp
# Load the Wikitext-2 dataset
dataset = nlp.load_dataset('wikitext', 'wikitext-2-raw-v1')
# Get the 11th example (index 10) of the train split
example = dataset['train'][10]
# Load the DistilBERT model and tokenizer
model = transformers.DistilBertModel.from_pretrained('distilbert-base-cased')
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-cased')
# Tokenize the example
input_ids = tokenizer.encode(example['text'], return_tensors='pt')
# Mask the sixth token (index 5) in the sequence
masked_input_ids = input_ids.clone()
masked_input_ids[:, 5] = tokenizer.mask_token_id
# Use the model to predict the most probable token for the masked token
output = model(masked_input_ids)[0]
prediction_scores, prediction_indexes = output[:, 5, :].max(dim=-1)
# Decode the predicted token ID to obtain the actual token
predicted_token = tokenizer.decode(prediction_indexes, skip_special_tokens=True)
# Replace the masked token with the predicted token in the input sequence
decoded_input_ids = input_ids.squeeze().tolist()
decoded_input_ids[5] = prediction_indexes.item()
decoded_input = tokenizer.decode(decoded_input_ids, skip_special_tokens=True)
print(f'Input: {example["text"]}')
print(f'Predicted token: {predicted_token}')
print(f'Decoded input: {decoded_input}')
The output I get is:
Predicted token: े
Decoded input: The game’s े system, the BliTZ system, is carried over directly from Valkyira Chronicles. During missions, players select each unit using a top @ - @ down perspective of the
What am I doing wrong here?
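One likely cause, offered as a sketch and not a confirmed diagnosis: DistilBertModel is the bare encoder, so model(masked_input_ids)[0] is the last hidden state with shape (batch, seq_len, hidden_size). Taking .max(dim=-1) on it selects a hidden-unit index (0 to 767 for this checkpoint), not a vocabulary ID, and decoding that small number yields an arbitrary low-ID token such as the one printed above. The version below assumes the masked-language-modelling head is what the exercise intends, and uses AutoModelForMaskedLM so the last dimension is the vocabulary. The names pick_top_token_id and predict_masked_token are hypothetical helpers introduced only for this illustration, and truncation=True is my own addition to keep long examples within the model's input limit. Separately, the nlp package was later renamed to datasets, so on a current install the import would be from datasets import load_dataset.

```python
def pick_top_token_id(scores):
    """Return the index of the largest value in a 1-D sequence of scores."""
    return max(range(len(scores)), key=lambda i: scores[i])


def predict_masked_token(text, mask_index=5, checkpoint="distilbert-base-cased"):
    """Mask one position of `text` and return the most probable decoded token.

    Hypothetical helper: `mask_index` counts positions in the encoded
    sequence, so index 0 is the [CLS] token added by the tokenizer.
    """
    # Imported lazily so pick_top_token_id can be used without these packages.
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # The masked-LM class adds the projection onto the vocabulary that the
    # bare DistilBertModel does not have.
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    # Assumption: truncate to the model's maximum input length.
    input_ids = tokenizer(text, return_tensors="pt", truncation=True)["input_ids"]
    input_ids[0, mask_index] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(input_ids).logits  # shape: (1, seq_len, vocab_size)

    # Sanity check: the last dimension should be the vocabulary, not hidden_size.
    assert logits.shape[-1] == model.config.vocab_size

    top_id = pick_top_token_id(logits[0, mask_index].tolist())
    return tokenizer.decode([top_id])
```

Calling predict_masked_token(example['text']) with the example loaded as in the question should then return a real word piece. Note that index 5 here includes the leading [CLS] token, which matches how the original code indexes input_ids.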