T5 user defined loss function

Hi,

Just a tip to save you some hassle, in case you don't already know this.

You're going to hit a snag with this idea if you try to backpropagate through the new loss, though it's perfectly fine as a logging metric.

Gradients cannot flow through a sampling or decoding step such as argmax, beam search, or nucleus sampling, because those operations are non-differentiable. If you add a loss computed on sampled token ids to your training objective, it will have no effect on your results.

loss = diversity_loss + lm_loss
loss.backward() # gradients from diversity_loss are all zero; the model still trains on lm_loss, so be careful: it can look like training is working while the diversity term does nothing!
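You can verify this for yourself in a couple of lines. The sketch below (PyTorch; `diversity` here is just a stand-in statistic, not your actual loss) shows that any tensor derived from `argmax` output has no `grad_fn`, so the graph is cut at that point:

```python
import torch

logits = torch.randn(3, 5, requires_grad=True)

# Path 1: a loss computed directly from the logits -- differentiable.
lm_loss = logits.softmax(dim=-1)[:, 0].sum()
lm_loss.backward()
print(logits.grad is not None)  # True: gradients flow back to the logits

logits.grad = None

# Path 2: argmax discretizes the logits into integer token ids.
token_ids = logits.argmax(dim=-1)    # int64 tensor, no grad_fn
diversity = token_ids.float().var()  # any "loss" built on top of the ids
print(diversity.requires_grad)       # False: the graph is cut at argmax
```

Calling `diversity.backward()` would raise an error, and adding it to `lm_loss` just contributes a constant, i.e. zero gradient. If you genuinely need to train through a sampling step, you'd have to look at relaxations or estimators such as Gumbel-Softmax or REINFORCE instead.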