Using the decoder half of BART for causal generation

My rudimentary understanding of BART is that it’s basically a BERT encoder feeding into a GPT-2 decoder. Is there a simple way to take a fine-tuned BART model and use just its decoder for text generation?

I’ve played with model.generate(input_ids=None, decoder_input_ids=my_tokenized_prompt,...) and the output looks reasonable enough. But 1) I don’t know if it’s actually giving me an accurate sense of what the decoder has learned and 2) I assume that’s an incredibly inefficient way to accomplish this task, and there must be a better option.
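In full, the call looks something like this (a sketch: the checkpoint name and prompt text are placeholders for my fine-tuned model and an actual lyric):

from transformers import BartForConditionalGeneration, BartTokenizer

# placeholder checkpoint; in practice this is my fine-tuned model
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

my_tokenized_prompt = tokenizer("I walked along the", return_tensors="pt").input_ids

# no encoder input at all; the prompt only seeds the decoder
output = model.generate(
    input_ids=None,
    decoder_input_ids=my_tokenized_prompt,
    max_length=50,
    do_sample=True,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))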

Is there a smarter way?

If you only want to use the decoder of BART, you can do so by simply using the BartDecoder class. So your code could look something like:

from transformers.models.bart.modeling_bart import BartDecoder

# note: loading the decoder class straight from a full BART checkpoint may
# leave some weights randomly initialized (watch the load-time warnings)
model = BartDecoder.from_pretrained("facebook/bart-base")

But note that BART is a seq2seq (encoder-decoder) model: it was pre-trained in an encoder-decoder set-up, so the best results will probably come from using both the encoder and the decoder. But of course you can still use only the decoder if you want 🙂
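For comparison, the standard encoder-decoder usage looks like this (a minimal sketch; the source text is a placeholder):

from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# the encoder reads the source text; the decoder generates conditioned on it
inputs = tokenizer("Some source text to condition on", return_tensors="pt")
output_ids = model.generate(inputs.input_ids, max_length=50)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))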


Thanks @nielsr, I’ll check that class out!

And I understand that I’ll get best results using the model as-trained. But I’m messing around with AI in creative projects, so “best” isn’t as important to me as “interesting” at the moment.

In the case of the project that prompted my original question, I had fine-tuned a BART model on song lyrics and song titles – first in the way you’d expect, to generate a title to a given set of lyrics, and then with the columns swapped, to generate lyrics given a title. And now I’m looking for the best way to use that trained decoder to generate lyrics continuing from a prompt, without feeding the model a title.

And this question will also come into play on a future project I’m planning, which needs a new language model for both CLM and seq2seq tasks. If I can get by with just training the language model once, it’ll save me a ton of time.


OK, makes sense! In that case, supposing you have already fine-tuned a BART model, you can access its decoder as follows:

from transformers import BartForConditionalGeneration

# load the full fine-tuned model, then pull out its decoder stack
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
decoder = model.model.decoder
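Note that the decoder by itself has no language-modeling head, so to turn its hidden states into tokens you also need the model's lm_head. Here is a minimal greedy-decoding sketch (the checkpoint and prompt are placeholders; when encoder_hidden_states is omitted, the decoder's cross-attention layers are skipped, so it behaves like a causal LM):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

# placeholder checkpoint; substitute your fine-tuned model
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

decoder = model.model.decoder
lm_head = model.lm_head

generated = tokenizer("I walked along the", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(40):  # generate up to 40 new tokens
        # no encoder_hidden_states passed, so cross-attention is skipped
        hidden = decoder(input_ids=generated).last_hidden_state
        logits = lm_head(hidden[:, -1, :]) + model.final_logits_bias
        next_token = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(generated[0], skip_special_tokens=True))

This loop re-encodes the whole prefix at every step (no key/value caching), so it's slow but easy to follow; model.generate with decoder_input_ids, as in the first post, does the same thing with caching.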

I have a question about decoder prompting. Can we do model.generate(..., decoder_input_ids=[], ...)? I've been looking through the documentation but couldn't find an answer.