Custom language modeling: generate words from context

Hi,

I’d like to augment an existing model such as T5 or XLNet so that it generates text using only words from a given context, and that context is different for each sample.

To do that, I added a layer on top of T5:

import torch.nn as nn

class T5TextGen(nn.Module):

    def __init__(self, base_model, max_seq_length_model):
        super().__init__()
        self.base_model = base_model
        # One logit per context position. d_model is 768 for t5-base;
        # reading it from the config avoids hard-coding the size.
        self.word_distrib = nn.Linear(base_model.config.d_model, max_seq_length_model)

    def forward(self, input_ids, decoder_input_ids, past_key_values=None):
        outputs = self.base_model(
            input_ids,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
        )
        # Logits of shape (batch, target_len, max_seq_length_model), plus the cache.
        return self.word_distrib(outputs['last_hidden_state']), outputs['past_key_values']

My idea was then to take the argmax of this last layer's output and use it as an index to retrieve the predicted word from the context, and to repeat this for every word.
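
To make the argmax-then-lookup step concrete, here is a minimal sketch in plain PyTorch. The shapes, the random logits, the token ids and the use of 0 as padding id are all assumptions for illustration; in the real setup the logits would come from the extra linear layer.

```python
import torch

# Toy shapes (assumed): batch of 2, 3 decoding steps, context padded to 5 positions.
torch.manual_seed(0)
logits = torch.randn(2, 3, 5)  # stand-in for the output of the extra linear layer

# Token ids of each sample's context; 0 is assumed to be the padding id.
context_ids = torch.tensor([[11, 12, 13, 0, 0],
                            [21, 22, 23, 24, 25]])
context_mask = context_ids != 0  # (batch, ctx_len), True on real words

# Block padded positions so argmax can never select them.
masked_logits = logits.masked_fill(~context_mask[:, None, :], float("-inf"))

# Index of the chosen context position at every decoding step.
positions = masked_logits.argmax(dim=-1)  # (batch, steps)

# Map positions back to the token ids of the words in each sample's context.
predicted_ids = torch.gather(context_ids, 1, positions)  # (batch, steps)
```

With this layout, the model output is a distribution over positions in the context rather than over the vocabulary, and the lookup is a single gather per batch.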

But I’m a bit lost on how to do that in terms of data representation. Should I use masks or not? Which loss should I use? Also, for the text generation process, should I iterate over forward and pass in the past_key_values from the previous prediction, or is there another way to do this?
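
One possible setup for the loss question, sketched under assumptions rather than taken from any reference implementation: treat each decoding step as a classification over context positions, mask out padded positions, and use cross-entropy with the gold position as target. The toy tensors and the -100 marker for padded target steps below are illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, steps, ctx_len = 2, 3, 5
logits = torch.randn(batch, steps, ctx_len)  # stand-in for the model's position logits

# True where the context position holds a real word (hypothetical example data).
context_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]], dtype=torch.bool)

# Target = index of the gold word inside the context; -100 marks padded target steps.
targets = torch.tensor([[0, 2, -100],
                        [4, 1, 3]])

# Padded context positions get zero probability after the softmax.
masked_logits = logits.masked_fill(~context_mask[:, None, :], float("-inf"))

# Cross-entropy over positions; steps whose target is -100 are skipped.
loss = F.cross_entropy(
    masked_logits.reshape(-1, ctx_len),
    targets.reshape(-1),
    ignore_index=-100,
)
```

This only works if every gold word actually appears in its sample's context, since each target has to point at a valid, unmasked position.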

If you have any resource where something like this is done, it would be really helpful.

Thanks!