Implementing the REINFORCE algorithm for an encoder-decoder model

I am trying to train an encoder-decoder model (EncoderDecoderModel) with the REINFORCE algorithm. For that, I need to decode with random sampling (do_sample=True) and keep the logits of the sampled tokens. I can't find a way to get those logits: in training mode, the model's forward pass doesn't sample at all, it only computes logits for the decoder inputs it is given (teacher forcing). The only way to apply random sampling seems to be the generate method, but I couldn't find a way to get the logits from it, and it doesn't keep gradients either.

Any workaround for it?
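One commonly used workaround, sketched here under assumptions rather than as an official API: sample with `generate(do_sample=True)`, which runs without gradients, then feed the sampled tokens back through a regular forward pass as `decoder_input_ids`. That second pass returns logits that do carry gradients, and the log-probabilities of the sampled tokens go into the REINFORCE loss. The tiny randomly initialised BERT2BERT model, the token ids `PAD_ID`/`BOS_ID`/`EOS_ID`, the helper names and the toy reward are all hypothetical, chosen only so the sketch runs offline. Check the `generate` arguments against your installed `transformers` version.

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Hypothetical special-token ids for the tiny model below.
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2


def build_tiny_model() -> EncoderDecoderModel:
    """Tiny randomly initialised BERT2BERT model so the sketch runs offline."""
    # Dropout is disabled so the sampling pass and the re-scoring pass
    # see exactly the same distribution.
    small = dict(vocab_size=50, hidden_size=32, num_hidden_layers=1,
                 num_attention_heads=2, intermediate_size=64,
                 hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        BertConfig(**small), BertConfig(**small))
    config.decoder_start_token_id = BOS_ID
    config.pad_token_id = PAD_ID
    config.eos_token_id = EOS_ID
    return EncoderDecoderModel(config=config)


def sequence_log_probs(model, input_ids, attention_mask, sequences):
    """Re-score sampled sequences with a teacher-forced forward pass,
    so the returned log-probabilities carry gradients."""
    decoder_input_ids = sequences[:, :-1]  # starts with decoder_start_token_id
    targets = sequences[:, 1:]             # the sampled tokens
    logits = model(input_ids=input_ids, attention_mask=attention_mask,
                   decoder_input_ids=decoder_input_ids).logits
    token_log_probs = logits.log_softmax(-1).gather(
        -1, targets.unsqueeze(-1)).squeeze(-1)
    # Keep tokens up to and including the first EOS; generate() pads the rest.
    is_eos = (targets == EOS_ID).long()
    eos_before = is_eos.cumsum(-1) - is_eos
    mask = (eos_before == 0).to(token_log_probs.dtype)
    return (token_log_probs * mask).sum(-1)  # shape: (batch,)


def reinforce_loss(model, input_ids, attention_mask, reward_fn, max_new_tokens=10):
    # 1) Sample; generate() runs under torch.no_grad(), so no graph is built here.
    #    top_k=0 disables top-k filtering, so tokens come from the full softmax.
    sequences = model.generate(
        input_ids=input_ids, attention_mask=attention_mask,
        do_sample=True, top_k=0, max_new_tokens=max_new_tokens,
        decoder_start_token_id=BOS_ID, pad_token_id=PAD_ID, eos_token_id=EOS_ID)
    # 2) Re-run the model on the sampled tokens to get differentiable log-probs.
    log_probs = sequence_log_probs(model, input_ids, attention_mask, sequences)
    # 3) REINFORCE with a mean-reward baseline: loss = -(R - b) * log p(y | x).
    rewards = reward_fn(sequences).to(log_probs.dtype)
    advantage = rewards - rewards.mean()
    return -(advantage * log_probs).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = build_tiny_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    input_ids = torch.randint(3, 50, (4, 6))
    attention_mask = torch.ones_like(input_ids)
    # Hypothetical reward: how often token 7 appears in each sample.
    reward_fn = lambda seqs: (seqs == 7).sum(-1).float()
    loss = reinforce_loss(model, input_ids, attention_mask, reward_fn)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The design choice here is to pay for a second forward pass instead of writing a custom sampling loop. `generate` can also return per-step scores via `output_scores=True, return_dict_in_generate=True`, but because it runs under `torch.no_grad()` those scores cannot be backpropagated, which is why the sketch re-scores the samples. Keeping `top_k=0` (and the default temperature of 1.0) matters: if sampling used a truncated or rescaled distribution, the log-probabilities from the plain forward pass would no longer be the probabilities the tokens were actually drawn from.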

Hi, did you find a way to do this?