I am trying to generate text using GPT2, using the code snippet at https://huggingface.co/transformers/quickstart.html (reproduced below). Unfortunately, it gives an error:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None

for i in range(100):
    print(i)
    output, past = model(context, past=past)
    token = torch.argmax(output[..., -1, :])
    generated += [token.tolist()]
    context = token.unsqueeze(0)

sequence = tokenizer.decode(generated)
print(sequence)
The error is raised on the line token = torch.argmax(output[..., -1, :]), saying TypeError: string indices must be integers. Can someone please help me out?
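For context, that TypeError is what this older snippet tends to produce on a newer transformers release: the forward call now returns a ModelOutput (a dict-like object), so output, past = model(...) ends up unpacking its string keys rather than tensors, and indexing the string "logits" then fails. Below is a minimal sketch of the same greedy loop written against the newer API, assuming a transformers 4.x install where the keyword argument is past_key_values and the logits are read from outputs.logits:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past_key_values = None

with torch.no_grad():
    for _ in range(100):
        # the model returns a ModelOutput; read its named fields instead of tuple-unpacking it
        outputs = model(context, past_key_values=past_key_values)
        past_key_values = outputs.past_key_values
        # greedy decoding: pick the most likely next token from the last position's logits
        token = torch.argmax(outputs.logits[..., -1, :])
        generated.append(token.item())
        context = token.unsqueeze(0)  # feed only the new token; the cache carries the history

sequence = tokenizer.decode(generated)
print(sequence)

The generate() call in the reply below does the same job with less code and is usually the easier route.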
2 years late, but did you find a solution?
No, unfortunately I didn’t…
Please use this code for text generation. Make sure you have torch and transformers installed:
- pip install torch
- pip install transformers
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# encode the prompt and return PyTorch tensors of input ids
generated = tokenizer.encode("The Manhattan bridge", return_tensors='pt')

output_sequences = model.generate(
    input_ids=generated,
    max_length=150,
    num_return_sequences=1,
    no_repeat_ngram_size=2,  # never repeat the same 2-gram
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id
)

# decode the generated ids back into text
text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(text)
Note: You can fine-tune these parameters for the best output, especially temperature.
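One caveat: temperature, top_k and top_p only influence the result when sampling is enabled; in the call above, generate() decodes greedily (do_sample defaults to False), so those three settings have no effect. A small variation that actually samples, reusing the model, tokenizer and generated tensor from the snippet above:

# enable sampling so temperature / top_k / top_p take effect
sampled_sequences = model.generate(
    input_ids=generated,
    do_sample=True,           # switch from greedy decoding to sampling
    max_length=150,
    num_return_sequences=3,   # sampling can return several different continuations
    no_repeat_ngram_size=2,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id
)
for seq in sampled_sequences:
    print(tokenizer.decode(seq, skip_special_tokens=True))

Lower temperatures keep the text close to the greedy output; higher values make it more varied.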
Hope you at least got an answer, even after 3 years. I just saw your post and thought I would answer you.