How to generate a sequence using inputs_embeds instead of input_ids?

Hello, I am struggling with generating a sequence of tokens using model.generate() with inputs_embeds.
For my research, I have to use inputs_embeds (word embedding vectors) instead of input_ids (token indices) as an input to the GPT2 model.
I want to use model.generate(), which is a convenient tool for generating a sequence of tokens, but it has no argument for inputs_embeds. I tried editing "transformers.generation_utils", but it was hard to figure out which lines to change.

Is there an easy way to generate tokens from inputs_embeds with the default hyperparameter settings, as model.generate() does? Any help would be appreciated.
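One version-independent workaround for GPT-2 is to skip generate() and write the decoding loop yourself. This is a minimal greedy-decoding sketch, assuming a Hugging Face decoder-only model (such as GPT2LMHeadModel) and unpadded inputs (no attention mask). The function name greedy_generate_from_embeds is hypothetical. It feeds inputs_embeds once, then feeds back only the embedding of each newly chosen token, reusing the key/value cache for earlier positions. Newer transformers releases may already accept inputs_embeds in generate() for decoder-only models, so check your installed version first.

```python
import torch


@torch.no_grad()
def greedy_generate_from_embeds(model, inputs_embeds, max_new_tokens=20, eos_token_id=None):
    """Hypothetical helper: greedy decoding for a decoder-only model from inputs_embeds.

    inputs_embeds: float tensor of shape (batch, seq_len, n_embd), no padding assumed.
    Returns a LongTensor of shape (batch, <= max_new_tokens) holding only the new ids.
    """
    embed = model.get_input_embeddings()  # token id -> vector lookup (wte for GPT-2)
    embeds = inputs_embeds
    past_key_values = None
    generated = []
    for _ in range(max_new_tokens):
        out = model(inputs_embeds=embeds, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        next_ids = out.logits[:, -1, :].argmax(dim=-1)  # most likely next token, (batch,)
        generated.append(next_ids)
        # Stop early only when every sequence in the batch has produced EOS.
        if eos_token_id is not None and bool((next_ids == eos_token_id).all()):
            break
        # Only the new token's embedding is fed; earlier positions come from the cache.
        embeds = embed(next_ids).unsqueeze(1)  # (batch, 1, n_embd)
    return torch.stack(generated, dim=1)
```

With a pretrained model this could be called as greedy_generate_from_embeds(GPT2LMHeadModel.from_pretrained("gpt2"), inputs_embeds, eos_token_id=tokenizer.eos_token_id), and the result decoded with tokenizer.batch_decode(...). Sampling, beam search, and the other generate() options are not reproduced here.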


Were you able to figure something out for this? My needs are similar, but I'm using a BART model.



Suppose you have an inputs_embeds tensor, which I believe should have shape (batch_size, seq_length, dim). If you instead have a hidden state of shape (batch_size, dim), just unsqueeze(dim=1) it to get (batch_size, 1, dim).
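A quick illustration of the reshape described above. The sizes are assumptions chosen for the example (768 happens to be the hidden size of both bart-base and gpt2); the random tensor stands in for a real hidden state.

```python
import torch

batch_size, dim = 4, 768
hidden_state = torch.randn(batch_size, dim)    # (batch_size, dim), placeholder values
inputs_embeds = hidden_state.unsqueeze(dim=1)  # (batch_size, 1, dim): a length-1 sequence
print(inputs_embeds.shape)  # torch.Size([4, 1, 768])
```

The unsqueezed tensor is a view of the same data, so inputs_embeds[:, 0, :] is exactly the original hidden state.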

You can then get the desired output by:

  • creating an attention mask of shape (batch_size, seq_len)
  • creating decoder_input_ids of shape (batch_size, 1)

Tracing these two parameters through generate() leads to the two functions that build them:

The attention mask you are already familiar with.

decoder_input_ids is all ones too, multiplied by config.decoder_start_token_id.
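An all-ones attention mask assumes no padding. If your inputs_embeds batch is right-padded, the mask needs zeros at the padded positions. Below is a minimal sketch of building both tensors, assuming you know each row's real length; the helper name build_generation_inputs and the lengths argument are hypothetical, not part of transformers.

```python
import torch


def build_generation_inputs(inputs_embeds, lengths, decoder_start_token_id):
    """Hypothetical helper: attention_mask and decoder_input_ids for encoder-decoder generate().

    inputs_embeds: (batch, seq_len, dim), assumed right-padded.
    lengths: LongTensor (batch,) with the number of real (non-padding) positions per row.
    """
    batch_size, seq_len = inputs_embeds.shape[:2]
    positions = torch.arange(seq_len, device=inputs_embeds.device)  # (seq_len,)
    # 1 where the position holds a real embedding, 0 where it is padding.
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()  # (batch, seq_len)
    # Every row starts decoding from the same start token; equivalent to ones(...) * id.
    decoder_input_ids = torch.full(
        (batch_size, 1), decoder_start_token_id, dtype=torch.long, device=inputs_embeds.device
    )
    return attention_mask, decoder_input_ids
```

When nothing is padded, every entry of lengths equals seq_len and the mask reduces to all ones.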

Putting it all together:

```python
import torch
from transformers import BartForConditionalGeneration

generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
inputs_embeds = torch.randn(2, 10, generator.config.d_model)  # placeholder: your 3D tensor, [batch, seq_length, dim]
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)
decoder_input_ids = torch.ones((inputs_embeds.shape[0], 1), dtype=torch.long) * generator.config.decoder_start_token_id
output_ids = generator.generate(attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, inputs_embeds=inputs_embeds, max_length=100, num_beams=4)
```
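A way to sanity-check this recipe is to build inputs_embeds from real token ids with the model's own embedding table and confirm that generate() returns the same ids either way. This is a sketch under an assumption: the model does not rescale embeddings (scale_embedding=False, which I believe is the setting in bart-base's config), so the table lookup matches what the encoder computes from input_ids. The function name embeds_match_ids is hypothetical.

```python
import torch


@torch.no_grad()
def embeds_match_ids(model, input_ids, **gen_kwargs):
    """Hypothetical check: does generate() give the same ids from inputs_embeds as from input_ids?

    Assumes scale_embedding=False, so the embedding-table lookup equals the encoder's input.
    """
    attention_mask = torch.ones_like(input_ids)
    decoder_input_ids = torch.full(
        (input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype=torch.long
    )
    inputs_embeds = model.get_input_embeddings()(input_ids)  # (batch, seq_len, d_model)
    from_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask,
                              decoder_input_ids=decoder_input_ids, **gen_kwargs)
    from_embeds = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids, **gen_kwargs)
    return torch.equal(from_ids, from_embeds)
```

For example, embeds_match_ids(generator, tokenizer("Hello world", return_tensors="pt").input_ids, max_length=20, num_beams=4) should return True if the embedding path behaves like the id path.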


Hey, thanks a lot for the reply; it works nicely!
One clarification: would I also need to set decoder_input_ids when training this model with inputs_embeds?

For BART, you can instead pass encoder_outputs, which you get by running the encoder part of the BART model on your inputs.
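A minimal sketch of that suggestion applied to a training step, assuming a BartForConditionalGeneration model and target token ids in labels; the function name bart_loss_from_embeds is hypothetical. On the decoder_input_ids question: when labels is passed and decoder_input_ids is omitted, BartForConditionalGeneration derives decoder_input_ids itself by shifting labels one position to the right, so the sketch does not set it.

```python
import torch


def bart_loss_from_embeds(model, inputs_embeds, attention_mask, labels):
    """Hypothetical helper: seq2seq loss for a BART-style model from inputs_embeds.

    The encoder is run explicitly and its output is passed as encoder_outputs.
    decoder_input_ids is omitted; the model builds it by shifting labels right.
    """
    encoder_outputs = model.get_encoder()(
        inputs_embeds=inputs_embeds, attention_mask=attention_mask
    )
    out = model(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        labels=labels,
    )
    return out.loss  # scalar; call .backward() on it in your training loop
```

Gradients flow back through the encoder into inputs_embeds, so the same function works whether the embeddings are fixed inputs or produced by a trainable upstream module.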