Top-k closest/similar words to the input word

Hello,

Given a pre-trained model, how can I retrieve the top-k closest words to a given word?
In other words, how can I see the vector representation of a word and the top-k closest vectors (and their corresponding words) from a pre-trained model?

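If “closest” means nearest neighbours in the model’s static input embedding space, one option is to rank the whole vocabulary by cosine similarity against the word’s embedding vector. This is only a minimal sketch, not an official similarity API: it assumes GPT-2, uses model.get_input_embeddings(), and assumes the word maps to a single token. (The replies below take a different, generation-based route.)

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Static input embedding matrix: (vocab_size, hidden_size)
embeddings = model.get_input_embeddings().weight

word = " dog"  # leading space matters: GPT-2's BPE treats " dog" as a single token
token_ids = tokenizer.encode(word)
assert len(token_ids) == 1, "pick a single-token word, or average the pieces"

word_vec = embeddings[token_ids[0]]
# Cosine similarity between this vector and every row of the embedding matrix
sims = F.cosine_similarity(word_vec.unsqueeze(0), embeddings, dim=-1)

k = 10
top = torch.topk(sims, k + 1)  # +1 because the word itself comes out first
for idx, score in zip(top.indices.tolist()[1:], top.values.tolist()[1:]):
    print(tokenizer.decode([idx]), round(score, 3))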

The snippet below is not a great fix, but it's what I use.

via @sgugger

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

# Load model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Input example
input_txt = "Hello, my name is Sylvain."
inputs = tokenizer(input_txt, return_tensors='pt')
outputs = model(**inputs)

# If you are not on a source install, replace outputs.logits by outputs[0]
predictions = F.softmax(outputs.logits, dim=-1)

thresh = 1e-2
vocab_size = predictions.shape[-1]

# Predictions has one sentence (index 0) and we look at the last token predicted (-1)
idxs = torch.arange(0, vocab_size)[predictions[0][-1] >= thresh]
print(tokenizer.convert_ids_to_tokens(idxs))
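
If you want exactly the top-k candidates rather than everything above a probability threshold, the last two lines can be swapped for torch.topk (reusing predictions from the snippet above):

# Top-k next-token candidates and their probabilities
k = 10
top = torch.topk(predictions[0][-1], k)
print(tokenizer.convert_ids_to_tokens(top.indices.tolist()))
print(top.values)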

You’d have to use an input like this: "At the core of the United States’ mismanagement of the Coronavirus lies its distrust of science. At the core of the United States’ mismanagement of the Coronavirus lies its"
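
Concretely, with the snippet above, that means setting input_txt to the repeated sentence, e.g.:

# Feed the repeated sentence so the next-token prediction "fills in" the missing word
input_txt = ("At the core of the United States' mismanagement of the Coronavirus "
             "lies its distrust of science. At the core of the United States' "
             "mismanagement of the Coronavirus lies its")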

You can also do the same thing, but by masking.

from transformers import RobertaTokenizer, RobertaForMaskedLM
import torch

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

sentence = """At the core of the United States' mismanagement of the Coronavirus lies its distrust of science. At the core of the United States' mismanagement of the Coronavirus lies its <mask> of science."""

token_ids = tokenizer.encode(sentence, return_tensors='pt')
print(tokenizer.tokenize(sentence))

# Find the position(s) of the <mask> token in the input
masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
masked_pos = [mask.item() for mask in masked_position]
print(masked_pos)

with torch.no_grad():
    output = model(token_ids)

# output[0] holds the masked-LM head logits: (batch, seq_len, vocab_size)
logits = output[0].squeeze()

print("\n\n")
print("sentence : ", sentence)
print("\n")

list_of_list = []
for mask_index in masked_pos:
    # Top 100 vocabulary ids for this mask position
    mask_logits = logits[mask_index]
    idx = torch.topk(mask_logits, k=100, dim=0)[1]
    words = [tokenizer.decode(i.item()).strip() for i in idx]
    list_of_list.append(words)
    print(words)

# Take the single best candidate for each mask position
best_guess = ""
for j in list_of_list:
    best_guess = best_guess + " " + j[0]
print("best guess:", best_guess.strip())
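
As a side note, the fill-mask pipeline wraps the same steps (tokenize, locate the mask, rank the vocabulary) in one call. Depending on your transformers version the keyword is top_k or topk, so treat this as a sketch rather than the exact API:

from transformers import pipeline

fill_mask = pipeline("fill-mask", model="roberta-base")
sentence = ("At the core of the United States' mismanagement of the Coronavirus lies its "
            "distrust of science. At the core of the United States' mismanagement of the "
            "Coronavirus lies its <mask> of science.")

# Each result carries the predicted token string, its score, and the filled sequence
for result in fill_mask(sentence, top_k=10):
    print(result["token_str"], result["score"])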