Hi there,
I’m currently working with BART for conditional generation and would like to generate some good old-fashioned sampled outputs (i.e. nothing fancy and no beam search) for experimentation.
According to the very nice blog post by @patrickvonplaten, this should be possible by passing do_sample=True and top_k=0 to the .generate() method.
However, the outputs I get appear to be beam sample outputs (i.e. beam search with sampling) rather than true samples: the returned object is a BeamSampleEncoderDecoderOutput and not the SampleEncoderDecoderOutput I would have expected.
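To make the comparison concrete, this is the kind of check I'm using to tell the two result types apart (assuming the output classes can be imported from transformers.generation_utils, which is where they live in my install):

from transformers.generation_utils import (
    BeamSampleEncoderDecoderOutput,
    SampleEncoderDecoderOutput,
)

# 'output' is the value returned by model.generate(..., return_dict_in_generate=True) below
print(isinstance(output, SampleEncoderDecoderOutput))      # what I expect: True
print(isinstance(output, BeamSampleEncoderDecoderOutput))  # what I actually see: True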
Here’s a minimal example of what I’m trying to do. Note that, for portability, I’m just using a pre-trained model here; in my original code I’m loading a fine-tuned model.
import sys
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# expects path to local HuggingFace pre-trained model
model_path = sys.argv[1]
num_return_sequences = 10
max_length = 256
# load from pre-trained
model = BartForConditionalGeneration.from_pretrained(model_path)
config = BartConfig.from_pretrained(model_path)
tokenizer = BartTokenizer.from_pretrained(model_path, use_fast=True)
if torch.cuda.is_available():
    model.cuda()
model.eval()
# list of sentences for decoding
sents = [
    'In its most basic form, sampling means randomly picking the next word according to its conditional probability distribution.',
    'Taking the example from above, the following graphic visualizes language generation when sampling.',
    'In transformers, we set do_sample=True and deactivate Top-K sampling (more on this later) via top_k=0.',
]
# decode each input sentence - no batching!
for sent in sents:
    inputs = tokenizer(sent, return_tensors="pt", padding=True, truncation=True, max_length=256)
    # ensure input tensors are on the same device as the model
    inputs = inputs.to(model.device)
    output = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        use_cache=True,
        max_length=max_length,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=0,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
    )
    print(type(output))  # reports BeamSampleEncoderDecoderOutput instead of the expected SampleEncoderDecoderOutput
    batch_hyp_strs = tokenizer.batch_decode(output.sequences.tolist(), skip_special_tokens=True)
    print('src:\t', sent)
    for hyp in batch_hyp_strs:
        print(f'\t{hyp}')
    print()
Any clarification on how to get true random samples from pre-trained/fine-tuned BART for conditional generation would be much appreciated!
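For what it's worth, my working guess is that generate() is picking up a num_beams value greater than 1 from the model's config, which would push it into beam-sample mode even with do_sample=True. If that's the case, explicitly forcing num_beams=1 is the workaround I'd try, sketched below (reusing model and inputs from the example above), but I'd like to confirm this is the intended way to get true samples:

output = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    do_sample=True,
    top_k=0,
    num_beams=1,  # explicitly override whatever num_beams the model config may set
    max_length=max_length,
    num_return_sequences=num_return_sequences,
    return_dict_in_generate=True,
)
print(type(output))  # hoping for SampleEncoderDecoderOutput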
Environment details
torch==1.8.0
transformers==4.9.0 (installed from source)