I’m trying to denoise text using a T5 model, following the Hugging Face documentation:
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
loss.item()
But I can’t figure out how to get the actual text that fills in the masked input. The docs only show how to get the loss, and mention that “the forward function automatically creates the correct decoder_input_ids”.
I don’t care about the loss, and I don’t have labels in my setting. I just have text with masked tokens that I need to fill:
my_masked_text = [
"The kid went to the [MASK]",
"The dog likes [MASK] and also [MASK]"
]
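Because T5 expects sentinel tokens rather than a `[MASK]` placeholder (see the reply below), one way to prepare strings like these is to rewrite each `[MASK]` as `<extra_id_0>`, `<extra_id_1>`, ... from left to right, matching the numbering in the documentation example. This is a minimal sketch assuming the literal `[MASK]` placeholder used above; `masks_to_sentinels` is a hypothetical helper, not part of the transformers API.

```python
import itertools
import re


def masks_to_sentinels(text: str, mask: str = "[MASK]") -> str:
    # A fresh counter per call, so every string starts again at <extra_id_0>
    ids = itertools.count()
    # re.sub calls the replacement function once per match, left to right
    return re.sub(re.escape(mask), lambda _m: f"<extra_id_{next(ids)}>", text)


my_masked_text = [
    "The kid went to the [MASK]",
    "The dog likes [MASK] and also [MASK]",
]
t5_inputs = [masks_to_sentinels(t) for t in my_masked_text]
print(t5_inputs)
# ['The kid went to the <extra_id_0>', 'The dog likes <extra_id_0> and also <extra_id_1>']
```

The converted strings can then be passed through the tokenizer and `model.generate` in place of the raw `[MASK]` text.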
Hi,
That’s explained in the docs:
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
sequence_ids = model.generate(input_ids)
sequences = tokenizer.batch_decode(sequence_ids)
sequences
Note that T5 doesn’t use a MASK token, but rather “sentinel tokens” such as <extra_id_0>. See the original T5 paper for details.
The output doesn’t make any sense though:
['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']
Does it think the sentence should be “The park offers walks in the park park”?
I tried other inputs and the results are even worse:
input_ids = tokenizer("The kid went to the <extra_id_1> to buy chocolate", return_tensors="pt").input_ids
Output:
['<pad><extra_id_0> kiddie kid<extra_id_1> kiddie<extra_id_2> kiddie kid kid<extra_id_3> kiddie kid<extra_id_4>kid kid']
Where did all these extra IDs come from?
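To see which sentinels in an output are “extra”, one option is to compare the sentinel IDs in the input with those in the generated text. The sketch below is a diagnostic only, not a fix; `sentinel_report` and its field names are hypothetical, and it assumes the decoded output keeps the `<extra_id_N>` tokens, as in the outputs shown above. Applied to the second example, it also shows that the input contains `<extra_id_1>` but no `<extra_id_0>`, whereas the documentation example numbers its sentinels from 0.

```python
import re

SENTINEL = re.compile(r"<extra_id_(\d+)>")


def sentinel_ids(text: str) -> list[int]:
    # Sentinel numbers in order of appearance
    return [int(n) for n in SENTINEL.findall(text)]


def sentinel_report(masked: str, generated: str) -> dict:
    in_ids = sentinel_ids(masked)
    out_ids = sentinel_ids(generated)
    return {
        "input": in_ids,
        "output": out_ids,
        # Sentinels the model emitted that have no slot in the input
        "extra_in_output": sorted(set(out_ids) - set(in_ids)),
        # True when the input sentinels run 0, 1, 2, ... like the doc example
        "input_numbered_from_zero": in_ids == list(range(len(in_ids))),
    }


masked = "The kid went to the <extra_id_1> to buy chocolate"
generated = ("<pad><extra_id_0> kiddie kid<extra_id_1> kiddie<extra_id_2> kiddie kid kid"
             "<extra_id_3> kiddie kid<extra_id_4>kid kid")
print(sentinel_report(masked, generated))
```

Cutting the output at the first sentinel that is absent from the input would hide the extra spans, but it would not change the quality of the spans that remain.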