I don’t know whether Gumbel-softmax can be used for text generation or not, but there is a paper on it.
As for the implementation, create a Gumbel distribution, `gumbel_dist = torch.distributions.gumbel.Gumbel(0., 1.)`, and add Gumbel noise to the output logits `logits = T5(...)[0]`: `new_logits = logits + gumbel_dist.sample(logits.shape)`. You could also look at my code.
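Adding the noise is only the first half; the Gumbel-softmax then applies a temperature-scaled softmax to the noisy logits. Below is a minimal sketch, assuming PyTorch and logits of shape `(batch, seq_len, vocab_size)` like the T5 output above. The function name `gumbel_softmax_sample` and its `tau`/`hard` parameters are hypothetical, and random logits stand in for the model output.

```python
import torch
import torch.nn.functional as F

# Standard Gumbel(0, 1) distribution used to perturb the logits.
gumbel_dist = torch.distributions.gumbel.Gumbel(0.0, 1.0)


def gumbel_softmax_sample(logits: torch.Tensor, tau: float = 1.0, hard: bool = False) -> torch.Tensor:
    """Hypothetical helper: relaxed one-hot sample over the last (vocabulary) dim."""
    # Gumbel noise with the same shape as the logits, moved to their device/dtype.
    noise = gumbel_dist.sample(logits.shape).to(device=logits.device, dtype=logits.dtype)
    # Lower tau -> samples closer to one-hot; higher tau -> smoother distribution.
    y_soft = F.softmax((logits + noise) / tau, dim=-1)
    if not hard:
        return y_soft
    # Straight-through estimator: one-hot in the forward pass,
    # gradients of the soft sample in the backward pass.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


# Usage with stand-in logits (batch=2, seq_len=5, vocab_size=10).
logits = torch.randn(2, 5, 10, requires_grad=True)
soft = gumbel_softmax_sample(logits, tau=0.5)
hard = gumbel_softmax_sample(logits, tau=0.5, hard=True)
```

PyTorch also ships `torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)`, which does the same thing in one call; writing it out just makes the noise-then-softmax steps explicit.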