Consider you have the tensor inputs_embeds, which should be of shape (batch_size, seq_length, dim). If instead you have a hidden_state of shape (batch_size, dim), just unsqueeze(dim=1) it so it becomes (batch_size, 1, dim).
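For illustration, a minimal sketch of that reshape (the tensor values and the 768 hidden size are dummies, just placeholders):

import torch

hidden_state = torch.randn(4, 768)             # (batch_size, dim), dummy values for illustration
inputs_embeds = hidden_state.unsqueeze(dim=1)  # -> (batch_size, 1, dim)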
Then you can achieve the desired output as follows:
- create an attention mask of shape (batch_size, seq_len)
- create decoder_input_ids of shape (batch_size, 1)
Tracing these two parameters leads to the two lines in the snippet below: the attention mask you are already familiar with (all ones here, assuming inputs_embeds contains no padding positions), and decoder_input_ids is ones too, multiplied by config.decoder_start_token_id.
Wrapping it all together:
import torch
from transformers import BartForConditionalGeneration

generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

inputs_embeds = ...  # your 3D tensor, [batch, seq_length, dim]; dim must match generator.config.d_model
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)
decoder_input_ids = torch.ones((inputs_embeds.shape[0], 1), dtype=torch.long) * generator.config.decoder_start_token_id
output_ids = generator.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, max_length=100, num_beams=4)
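If you want text rather than token ids, you can decode the generated ids with the matching tokenizer; a minimal sketch, assuming the same checkpoint as above:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)  # list of generated strings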