Generate function returns random words for BartForConditionalGeneration

Hello, I am trying to use BartForConditionalGeneration, but the generate function returns random words. The model I am using is facebook/bart-base and it has not been fine-tuned at all. Here is the code:

import torch
from transformers import AutoTokenizer, BartConfig, BartForConditionalGeneration

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

max_length = 40  # same value I pass to generate below
model_cfg = BartConfig.from_pretrained(model_name)
model_cfg.max_length = max_length
# model_cfg.use_cuda = True
model_cfg.force_bos_token_to_be_generated = True
model = BartForConditionalGeneration(model_cfg).to(device)

Here is the generation code:

from transformers.models.bart.modeling_bart import shift_tokens_right
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."

inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors='pt')

summary_ids = model.generate(
    shift_tokens_right(inputs['input_ids'], tokenizer.bos_token_id, tokenizer.eos_token_id),
    num_beams=4, max_length=40)
print([tokenizer.decode(g, skip_special_tokens=True) for g in summary_ids])

The output of this is:

[' society society society Steve Steve Steve smoking smoking smoking contiguous contiguous contiguous Canucks Canucks Canucks concurrent concurrent concurrentRandRandRandLessLessLess Providence Providence ProvidenceumatumatumatCombatCombatCombat vicious vicious vicious Kre']
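In case it matters, here is what shift_tokens_right does to the input, as far as I can tell from the transformers source (note that in recent versions its parameters are pad_token_id and decoder_start_token_id, not bos/eos). This is my own pure-Python sketch of its behavior, not the library code:

```python
# Sketch of shift_tokens_right's behavior (my reimplementation, not the
# transformers code): prepend the decoder start token, drop the last token,
# and replace label padding (-100) with pad_token_id.
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted = [[decoder_start_token_id] + row[:-1] for row in input_ids]
    return [[pad_token_id if tok == -100 else tok for tok in row]
            for row in shifted]

# e.g. a batch of one sequence: <s>=0 ... </s>=2, pad=1, decoder start=2
print(shift_tokens_right([[0, 31414, 232, 2]], 1, 2))  # [[2, 0, 31414, 232]]
```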

I have been trying to solve issues like this with generate for nearly three and a half days, but have no idea…

Any help is incredibly appreciated

Nvm. I realized I still needed to load the pretrained weights with from_pretrained rather than instantiating the model from the config alone, which leaves the weights randomly initialized.
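For anyone who hits the same thing, here is a minimal sketch of the corrected loading code (same model name as above; the rest of the script is unchanged):

```python
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_name = "facebook/bart-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# from_pretrained loads the trained weights; BartForConditionalGeneration(config)
# only builds the architecture with randomly initialized parameters.
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)

inputs = tokenizer(["My friends are cool but they eat too many carbs."],
                   return_tensors='pt').to(device)
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=40)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```

With the pretrained weights loaded, bart-base roughly reconstructs the input instead of emitting random tokens.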