@MoritzLaurer's code is really cool, but it will most likely strip off the special tokens and pass only the input_ids to the model. I thought: why not pass the model whatever it originally saw during training? So I have rewritten the logic below, and also made sure that sentences are not broken into parts, with the help of nltk. @echatzikyriakidis
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

nltk.download('punkt')

checkpoint = "google/pegasus-xsum"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

long_text = "This is a very very long text. " * 300
sentences = nltk.tokenize.sent_tokenize(long_text)

# initialize
length = 0
chunk = ""
chunks = []
count = -1
for sentence in sentences:
    count += 1
    combined_length = len(tokenizer.tokenize(sentence)) + length  # add the no. of sentence tokens to the length counter

    if combined_length <= tokenizer.max_len_single_sentence:  # if it doesn't exceed
        chunk += sentence + " "  # add the sentence to the chunk
        length = combined_length  # update the length counter

        # if it is the last sentence
        if count == len(sentences) - 1:
            chunks.append(chunk)  # save the chunk
    else:
        chunks.append(chunk)  # save the chunk

        # reset
        length = 0
        chunk = ""

        # take care of the overflow sentence
        chunk += sentence + " "
        length = len(tokenizer.tokenize(sentence))

        # if the overflow sentence is also the last one, save it too
        # (otherwise the final chunk would be silently dropped)
        if count == len(sentences) - 1:
            chunks.append(chunk)

# inputs
inputs = [tokenizer(chunk, return_tensors="pt") for chunk in chunks]

# print summary
for inp in inputs:  # renamed from `input` to avoid shadowing the built-in
    output = model.generate(**inp)
    print(tokenizer.decode(*output, skip_special_tokens=True))
and the output is:
This is a very very long text.
This is a very very long text.
This is a very very long text.
This is a very very long text.
This is a very very long text.
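As a side note, instead of generating chunk by chunk, the chunks could also be tokenized as one padded batch and summarized in a single generate call. This is just a minimal sketch assuming the tokenizer, model, and chunks variables from above:

# tokenize all chunks at once, padding to the longest chunk in the batch
batch = tokenizer(chunks, truncation=True, padding="longest", return_tensors="pt")
# generate one summary per chunk in a single call
summaries = model.generate(**batch)
print(tokenizer.batch_decode(summaries, skip_special_tokens=True))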
Some intermediate results:
[len(tokenizer(c.strip()).input_ids) for c in chunks]
gives:
[505, 505, 505, 505, 385]
which are well within tokenizer.model_max_length of 512.
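That headroom is exactly why the loop budgets against tokenizer.max_len_single_sentence rather than tokenizer.model_max_length: it is the model maximum minus the number of special tokens the tokenizer adds back in, so the chunks still fit after encoding. A quick check (the printed values are what I'd expect for pegasus-xsum, assuming the tokenizer from above):

print(tokenizer.model_max_length)             # 512
print(tokenizer.num_special_tokens_to_add())  # 1 (the EOS token)
print(tokenizer.max_len_single_sentence)      # 512 - 1 = 511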
Do let me know if anything seems weird.