How can I sample with BART for conditional generation?

Hi there,

I’m currently working with BART for conditional generation and would like to generate some good old-fashioned sampled outputs (i.e. nothing fancy and no beam search) for experimentation.

According to the very nice blog post by @patrickvonplaten, this should be possible by providing do_sample=True and top_k=0 to the .generate() method.

However, it seems that the outputs are actually beam search outputs and not true samples, e.g. the returned output object is a BeamSampleEncoderDecoderOutput and not a SampleEncoderDecoderOutput, as I would have expected.

Here’s a minimal example of what I’m trying to do. Note that, for portability, I’m just using the pre-trained model in this example, but in my original code I’m loading a fine-tuned model.

import sys
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

# expects path to local HuggingFace pre-trained model
model_path = sys.argv[1]
num_return_sequences = 10
max_length = 256

# load from pre-trained
model = BartForConditionalGeneration.from_pretrained(model_path)
config = BartConfig.from_pretrained(model_path)
tokenizer = BartTokenizer.from_pretrained(model_path, use_fast=True)    

if torch.cuda.is_available():
    model.cuda()

model.eval()

# list of sentences for decoding    
sents = [
    'In its most basic form, sampling means randomly picking the next word according to its conditional probability distribution.',
    'Taking the example from above, the following graphic visualizes language generation when sampling.',
    'In transformers, we set do_sample=True and deactivate Top-K sampling (more on this later) via top_k=0.']

# decode each input sentence - no batching!
for sent in sents:
    inputs = tokenizer(sent, return_tensors="pt", padding=True, truncation=True, max_length=256)
    # ensure input tensors are on the same device as the model
    inputs = inputs.to(model.device)
    output = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        use_cache=True,
        max_length=max_length,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_k=0,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
    )
    print(type(output))
    
    batch_hyp_strs = tokenizer.batch_decode(output.sequences.tolist(), skip_special_tokens=True)

    print('src:\t', sent)
    for hyp in batch_hyp_strs:
        print(f'\t{hyp}')
    print()
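
For reference, a quick way to check which decoding branch .generate() took is to test the output type directly. This is just a small sketch reusing output from the loop above; in transformers 4.9 these output classes live in generation_utils:

from transformers.generation_utils import (
    BeamSampleEncoderDecoderOutput,
    SampleEncoderDecoderOutput,
)

# with the call above this prints True, i.e. beam sampling was used ...
print(isinstance(output, BeamSampleEncoderDecoderOutput))
# ... whereas plain sampling would produce a SampleEncoderDecoderOutput instead
print(isinstance(output, SampleEncoderDecoderOutput))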

Any clarification on how to get true random samples from pre-trained/fine-tuned BART for conditional generation would be much appreciated!

Environment details

torch==1.8.0
transformers==4.9.0 (installed from source)

Answering my own question here in case anyone ever runs into the same issue.

The num_beams argument of .generate() defaults to the value specified in the model’s config file (e.g. config.json). If this value is > 1, decoding will be performed with beam search (either regular beam search or beam sampling, depending on do_sample) rather than plain sampling.
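
You can check where this default comes from by inspecting the config directly. A quick sketch (facebook/bart-large-cnn is just an illustrative hub checkpoint here, not the model from my original code):

from transformers import BartConfig

# generation defaults are baked into the checkpoint's config.json
config = BartConfig.from_pretrained('facebook/bart-large-cnn')
print(config.num_beams)   # if > 1, .generate() uses beam search / beam sampling by default
print(config.do_sample)   # other sampling-related defaults can also live here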

So, the simple fix is to explicitly set num_beams=1 in the call to .generate(), e.g.:

output = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    use_cache=True,
    max_length=max_length,
    pad_token_id=tokenizer.pad_token_id,
    decoder_start_token_id=tokenizer.pad_token_id,
    do_sample=True,
    top_k=0,
    num_beams=1,
    num_return_sequences=num_return_sequences,
    return_dict_in_generate=True,
)

print(type(output)) # ==> transformers.generation_utils.SampleEncoderDecoderOutput
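
Alternatively (a rough sketch reusing the variables from the example above), you could override the default once on the loaded model’s config, since .generate() falls back to model.config for any argument you don’t pass explicitly:

# set the default once instead of passing num_beams on every call
model.config.num_beams = 1

output = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    do_sample=True,
    top_k=0,
    max_length=max_length,
    num_return_sequences=num_return_sequences,
    return_dict_in_generate=True,
)
print(type(output))  # should again be a SampleEncoderDecoderOutput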

Lesson of the day: keep calm and read the docstrings carefully (see transformers/generation_utils.py at 91ff480e2693f36b11aaebc4e9cc79e4e3c049da · huggingface/transformers · GitHub)! :wink: