@dipanjanS's code snippet is a good option using NLTK. Here is an alternative "pure transformers" solution:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
long_text = "This is a very very long text. " * 300
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
# tokenize without truncation
inputs_no_trunc = tokenizer(long_text, max_length=None, return_tensors='pt', truncation=False)
# get batches of tokens corresponding to the exact model_max_length
chunk_start = 0
chunk_end = tokenizer.model_max_length # == 1024 for Bart
inputs_batch_lst = []
while chunk_start < len(inputs_no_trunc['input_ids'][0]):
    inputs_batch = inputs_no_trunc['input_ids'][0][chunk_start:chunk_end]  # get batch of n tokens
    inputs_batch = torch.unsqueeze(inputs_batch, 0)
    inputs_batch_lst.append(inputs_batch)
    chunk_start += tokenizer.model_max_length  # == 1024 for Bart
    chunk_end += tokenizer.model_max_length  # == 1024 for Bart
# generate a summary on each batch
summary_ids_lst = [model.generate(inputs, num_beams=4, max_length=100, early_stopping=True) for inputs in inputs_batch_lst]
# decode the output and join into one string with one paragraph per summary batch
summary_batch_lst = []
for summary_id in summary_ids_lst:
    summary_batch = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_id]
    summary_batch_lst.append(summary_batch[0])
summary_all = '\n'.join(summary_batch_lst)
print(summary_all)
# output (would of course make more sense on a sensible input):
This is a very very long text. This is avery very long texts. This has been a very long day for me. I'm going to have to take a break from this. I've got a lot of work to do. I'll be back in a few days.
This is a very very long text. This is avery very long texts. This has been a very long day for me. I'm going to have to take a break from this. I've got a lot of work to do. I'll be back in a few days.
This is a very very long text. This is avery very long texts. This has been a very, very long day. This will be a very long, very, long night. I hope you enjoy it. I will be back in a week or so with a new text.
The main advantage of this approach is that it uses the tokenization directly from the transformers tokenizer instead of an external tokenizer like NLTK. Keep in mind that most transformer models use sub-word tokenizers, while NLTK uses a word-level tokenizer (see explanation here). This means that NLTK will split a string like "I have a new GPU!" into 6 tokens (one per word plus punctuation), while e.g. BERT's tokenizer will split it into 7 (['i', 'have', 'a', 'new', 'gp', '##u', '!']), because it splits rare words like "GPU" into sub-words. With the "pure transformers" approach you can be sure to really get the exact maximum number of tokens.
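To illustrate the difference, here is a quick comparison of the two tokenizers on that example (assuming nltk is installed and the punkt data has been downloaded):

from nltk.tokenize import word_tokenize  # requires: nltk.download('punkt')
from transformers import BertTokenizer

text = "I have a new GPU!"

# word-level tokenization: 6 tokens, one per word/punctuation mark
print(word_tokenize(text))
# ['I', 'have', 'a', 'new', 'GPU', '!']

# sub-word tokenization: 7 tokens, 'GPU' is split into 'gp' + '##u'
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(bert_tokenizer.tokenize(text))
# ['i', 'have', 'a', 'new', 'gp', '##u', '!']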
The disadvantage is that there is no sentence boundary detection. You could theoretically solve that with the NLTK (or SpaCy) approach by splitting the text into sentences first (see the sketch below). But the token threshold should then probably be set below 1024 words (maybe 900?), because 1024 NLTK word tokens translate into more than 1024 sub-word tokens. Otherwise, the text gets truncated again and you effectively delete parts of your text.
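A rough sketch of that sentence-based variant could look like this (the 900-token budget is just an assumed safety margin, sent_tokenize needs NLTK's punkt data, and tokenizer is the BartTokenizer from above):

from nltk.tokenize import sent_tokenize  # requires: nltk.download('punkt')

def chunk_by_sentences(text, tokenizer, max_tokens=900):
    """Greedily pack whole sentences into chunks of at most max_tokens sub-word tokens."""
    chunks, current_chunk, current_len = [], [], 0
    for sentence in sent_tokenize(text):
        n_tokens = len(tokenizer.tokenize(sentence))
        # start a new chunk if adding this sentence would exceed the budget
        if current_chunk and current_len + n_tokens > max_tokens:
            chunks.append(' '.join(current_chunk))
            current_chunk, current_len = [], 0
        current_chunk.append(sentence)  # note: a single sentence longer than max_tokens still becomes its own (oversized) chunk
        current_len += n_tokens
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks

# each chunk can then be tokenized and summarised with model.generate(...) as above
chunks = chunk_by_sentences(long_text, tokenizer, max_tokens=900)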
I feel like summarising texts above 1024 tokens is probably a common use case, and enabling this kind of "long text summarisation" could be a very useful feature for the summarisation pipeline. Could this maybe be something that could be added to the pipeline with e.g. a keyword argument like summarise_long_text=True?
I don't know the internals of the pipeline well enough to know whether this would be an easy addition or too complicated @sshleifer. (Also please correct me if I'm wrong about the code and explanation above, I'm also new to this.)
(Btw, look at the input and the output… is Bart getting lazy when the text is too long and monotonous?)