How can I do text summarization using ProphetNet?

ProphetNet automatically shifts the tokens, so you can compute the loss as follows:

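```python
from transformers import ProphetNetForConditionalGeneration

prophetnet = ProphetNetForConditionalGeneration.from_pretrained(...)
# labels are shifted right inside the model to build decoder_input_ids
loss = prophetnet(input_ids=tokenized_article, labels=tokenized_summary).loss
```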
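The loss call above is the core of fine-tuning. A minimal end-to-end sketch of a training loop followed by summary generation might look like this, assuming PyTorch and a recent `transformers` release. To keep it runnable offline, it uses a tiny, randomly initialised `ProphetNetConfig` and random token ids (`article_ids`, `summary_ids` are hypothetical stand-ins) instead of a real checkpoint and tokenizer, so the generated "summary" is meaningless; the model sizes, learning rate, step count, and beam settings are arbitrary illustrative choices.

```python
import torch
from transformers import ProphetNetConfig, ProphetNetForConditionalGeneration

torch.manual_seed(0)

# Assumption: a tiny random config so the sketch runs without downloads.
# In practice, load a pretrained checkpoint with from_pretrained(...).
config = ProphetNetConfig(
    vocab_size=100,
    hidden_size=32,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_encoder_attention_heads=4,
    num_decoder_attention_heads=4,
    max_position_embeddings=64,
)
model = ProphetNetForConditionalGeneration(config)

# Hypothetical token ids standing in for tokenized articles and summaries
# (ids 0-2 are left free for pad/bos/eos).
article_ids = torch.randint(3, config.vocab_size, (2, 20))
summary_ids = torch.randint(3, config.vocab_size, (2, 8))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()
losses = []
for step in range(20):
    # Only labels are passed; the model shifts them right to form the decoder inputs.
    loss = model(input_ids=article_ids, labels=summary_ids).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())

# After training, produce summaries with beam search.
model.eval()
with torch.no_grad():
    summary = model.generate(article_ids, num_beams=2, max_length=12)
```

For real summaries, replace the tiny config with `ProphetNetForConditionalGeneration.from_pretrained(...)`, produce the input and label ids with the matching `ProphetNetTokenizer`, and turn the output back into text with `tokenizer.batch_decode(summary, skip_special_tokens=True)`.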