I am trying to train an encoder-decoder model (`EncoderDecoderModel`) using the REINFORCE algorithm. For that, I need to decode with random sampling (`do_sample=True`) and use the logits of the sampled tokens. I can't find a way to get those logits because, in training mode, the decoding used is `greedy_decoding`. Random sampling seems to be available only through the `generate` method, but I couldn't find a way to get the logits from it and, in addition, it doesn't store the gradients.
Any workaround for it?
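
One possible workaround, sketched here under assumptions rather than as a confirmed recipe: sample with `generate` without gradients, then run the sampled sequence back through the model's normal forward pass to get logits that do carry gradients, and build the REINFORCE loss from the log-probabilities of the sampled tokens. The helper name `reinforce_loss`, the per-sequence `rewards` tensor, and the random stand-in logits are hypothetical and only there so the snippet runs without a checkpoint.

```python
import torch
import torch.nn.functional as F


def reinforce_loss(logits, sampled_ids, rewards, pad_token_id):
    """REINFORCE loss for already-sampled sequences (hypothetical helper).

    logits:      (batch, seq_len, vocab) decoder logits computed WITH gradients
    sampled_ids: (batch, seq_len) token ids that were sampled
    rewards:     (batch,) one scalar reward per sequence (assumed given)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability of each token that was actually sampled.
    token_log_probs = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    # Padding positions contribute nothing to the loss.
    mask = (sampled_ids != pad_token_id).to(token_log_probs.dtype)
    seq_log_probs = (token_log_probs * mask).sum(dim=-1)
    # Minimizing -reward * log p(sample) follows the REINFORCE gradient.
    return -(rewards.detach() * seq_log_probs).mean()


# With a real EncoderDecoderModel the two steps would look roughly like this
# (assumes generate() returns sequences that begin with the decoder start token):
#   with torch.no_grad():
#       sampled = model.generate(input_ids, do_sample=True)
#   logits = model(input_ids=input_ids,
#                  decoder_input_ids=sampled[:, :-1]).logits
#   loss = reinforce_loss(logits, sampled[:, 1:], rewards, pad_token_id)

# Stand-in tensors so the sketch is runnable on its own.
torch.manual_seed(0)
logits = torch.randn(2, 4, 10, requires_grad=True)
sampled_ids = torch.tensor([[3, 5, 1, 0], [2, 2, 0, 0]])  # 0 is padding here
rewards = torch.tensor([1.0, 0.5])

loss = reinforce_loss(logits, sampled_ids, rewards, pad_token_id=0)
loss.backward()  # gradients flow to the logits, and to the model in real use
```

Since dropout is active in training mode, the recomputed probabilities may differ slightly from those used while sampling. Whether that matters depends on the setup.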