I am working on an encoder-decoder model that uses a fine-tuned RoBERTa classifier as the encoder and GPT2 as the decoder. Before passing the encoder context to the decoder, I mix it with some context from a different domain; this mixing module is a simple NN. I now want to pass the transformed hidden states to the GPT2 decoder for decoding, and I will train only the decoder and the mixer, not the encoder. How can I pass these transformed hidden states to the GPT2 decoder instead of `input_ids` or `inputs_embeds`? The shape of my transformed hidden states is `(n_layers, batch_size, sequence_length, hidden_size)`, where I am currently using `batch_size=1`, and `sequence_length` is 1 because I took only the `[CLS]` token hidden states from the encoder. Any help will be appreciated.
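For context, here is a minimal sketch of the setup I mean. The dimensions, the `Mixer` architecture, and all names are placeholders I made up for illustration; the real encoder is RoBERTa and the real domain context comes from elsewhere:

```python
import torch
import torch.nn as nn

# Toy dimensions; my real setup uses hidden_size=768 with batch_size=1
n_layers, batch_size, seq_len, hidden_size = 13, 1, 1, 768

# Stand-ins for the encoder's per-layer [CLS] hidden states,
# shape (n_layers, batch_size, sequence_length, hidden_size),
# and for the other-domain context
encoder_cls_states = torch.randn(n_layers, batch_size, seq_len, hidden_size)
domain_context = torch.randn(batch_size, seq_len, hidden_size)

# Placeholder mixing module: a simple NN combining the two contexts
class Mixer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, enc_states, ctx):
        # take the last layer's [CLS] state, shape (batch, seq, hidden)
        return self.proj(torch.cat([enc_states[-1], ctx], dim=-1))

mixer = Mixer(hidden_size)
mixed = mixer(encoder_cls_states, domain_context)
print(mixed.shape)  # torch.Size([1, 1, 768])
```

It is this `mixed` tensor that I want the GPT2 decoder to condition on in place of `input_ids` or `inputs_embeds`.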