Text generation with XLNet not working

I’m having some trouble with text generation when using XLNet. Here is my code:

from transformers import XLNetLMHeadModel, XLNetTokenizer
import torch
from time import time

PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> """

def prepare_xlnet_input(tokenizer, prompt_text):
    prompt_text = PADDING_TEXT + prompt_text
    return prompt_text

tokenizer = XLNetTokenizer.from_pretrained("xlnet-large-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-large-cased", mem_len=0)

prompt_text = "How are you "
input_text = prepare_xlnet_input(tokenizer, prompt_text)
generated = tokenizer.encode(prompt_text, add_special_tokens=False)
context = torch.tensor([tokenizer.encode(input_text, add_special_tokens=False)])
past = None

length = 10

tic = time()
for i in range(length):
    print(i)
    #Create dummy token for input
    input_data = torch.cat([context, torch.zeros((1, 1), dtype=torch.int64)], dim=1)

    #Create target mapping mask
    target_mapping = torch.zeros((1,1,input_data.shape[1]))
    target_mapping[:,:,-1] = 1
    print(target_mapping.shape)

    #Create permutation mask
    perm_mask = torch.zeros(1,input_data.shape[1],input_data.shape[1])
    perm_mask[:,-1,:] = 1
    print('Perm mask shape',perm_mask.shape)

    #Run data through model
    print('Input data: ', input_data[0,-10:-1])
    results = model(input_data, mems=past, use_cache=False, target_mapping=target_mapping, perm_mask=perm_mask)
    output = results[0]
    #past = results[1]

    #Get most probable token (greedy)
    print('Results: ',len(results))
    print('Output: ',output[0,0,0:20])
    print('Output shape: ',output.shape)
    token = torch.argmax(output,dim=-1)

    print('Token: ',token)
    generated += [token.squeeze(0).squeeze(0).tolist()]
    context = torch.cat([context, token],dim=1)

print(f'Time to decode {length} tokens: ',time() - tic)
sequence = tokenizer.decode(generated)
print('################### OUTPUT TEXT #####################')
print(sequence)

And here is the output:

......
8
torch.Size([1, 1, 172])
Perm mask shape torch.Size([1, 172, 172])
Input data:  tensor([44, 19, 19, 19, 19, 19, 19, 19, 19])
Results:  1
Output:  tensor([-11.8758, -22.3436, -22.2104, -21.2482, -19.0063, -22.1941, -22.2687,
        -13.8716,  -9.6080,  -5.0403, -11.9773, -10.5682,  -9.9979,  -8.1426,
        -10.3780, -19.4922, -18.2674,  -7.5363,  -5.8832,  -3.6131],
       grad_fn=<SliceBackward>)
Output shape:  torch.Size([1, 1, 32000])
Token:  tensor([[19]])
9
torch.Size([1, 1, 173])
Perm mask shape torch.Size([1, 173, 173])
Input data:  tensor([19, 19, 19, 19, 19, 19, 19, 19, 19])
Results:  1
Output:  tensor([-12.3357, -22.8729, -22.7428, -21.7369, -19.5315, -22.7247, -22.7972,
        -14.3732, -10.1052,  -5.5922, -12.4161, -11.0451, -10.4621,  -8.6313,
        -10.8699, -19.9730, -18.7067,  -8.0223,  -6.4285,  -4.0735],
       grad_fn=<SliceBackward>)
Output shape:  torch.Size([1, 1, 32000])
Token:  tensor([[19]])
Time to decode 10 tokens:  11.015855312347412
################### OUTPUT TEXT #####################
How are you,,,,,,,,,,

It just predicts commas the entire way through.
Another strange thing I noticed is that the predicted logits steadily increase in magnitude with every iteration; they don't change relative to one another at all, they just grow.

My code is very similar to the example code, so I also tried the example from https://huggingface.co/transformers/model_doc/xlnet.html?highlight=tfxlnet#xlnetlmheadmodel directly. Although it gave the wrong output, it at least produced something:

from transformers import XLNetTokenizer, XLNetLMHeadModel
import torch

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

# We show how to set up inputs to predict a next token using a bi-directional context.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

# XLNetLMHeadModel can be trained the same way with standard auto-regressive language modeling.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
assert labels.shape[0] == 1, 'only one word will be predicted'
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token, as in standard auto-regressive LM training
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
loss, next_token_logits = outputs[:2]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

next_logit = torch.argmax(next_token_logits[0, 0, :])
next_token = tokenizer.convert_ids_to_tokens(int(next_logit))
print(next_token)

Running that, it predicts the word ‘very’ instead of the intended ‘cute’.
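A quick way to sanity-check this kind of thing is to look at the top-k candidates instead of only the argmax, e.g.:

top_logits, top_ids = torch.topk(next_token_logits[0, 0, :], k=10)
print(tokenizer.convert_ids_to_tokens(top_ids.tolist()))  # the ten highest-scoring candidates for the masked position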

However, running run_generation.py with XLNet works just fine. Even when I changed it to decode greedily instead of sampling, it still produced valid results. I verified that the shape of my input data, the input data itself, the permutation mask, and the target mapping are all the same as what run_generation.py uses, but I just can't get valid results out of my own code.
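For reference, run_generation.py goes through the high-level API. Assuming a transformers version where XLNetLMHeadModel supports generate() (it builds the dummy token, perm_mask, and target_mapping internally), the working path boils down to something like:

input_ids = torch.tensor([tokenizer.encode(PADDING_TEXT + prompt_text, add_special_tokens=False)])
output_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 10, do_sample=False)  # greedy decoding
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:]))  # only the newly generated tokens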

What am I doing wrong?

Turns out I made a mistake with the permutation mask:

perm_mask[:,-1,:] = 1 should be perm_mask[:,:,-1] = 1

Per the docs, perm_mask[k, i, j] = 1 means token i does not attend to token j. My version blinded the dummy token to the entire context, so the model fell back to predicting the most frequent token unconditionally: a comma. The correct version hides the dummy token from all positions (including itself), which is exactly what you want when predicting it.
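For completeness, here is a cleaned-up version of the decoding loop with the fix applied:

for i in range(length):
    # Append a dummy token that stands in for the position to predict
    input_data = torch.cat([context, torch.zeros((1, 1), dtype=torch.int64)], dim=1)

    # Predict only the final (dummy) position
    target_mapping = torch.zeros((1, 1, input_data.shape[1]))
    target_mapping[:, :, -1] = 1

    # No token may attend to the dummy token's content (the fixed line)
    perm_mask = torch.zeros((1, input_data.shape[1], input_data.shape[1]))
    perm_mask[:, :, -1] = 1

    output = model(input_data, use_cache=False, target_mapping=target_mapping, perm_mask=perm_mask)[0]
    token = torch.argmax(output, dim=-1)  # greedy pick, shape [1, 1]
    generated += [token.item()]
    context = torch.cat([context, token], dim=1)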

There’s a day of my life I’ll never get back. :sweat_smile:
