Since your model returns the logits, you can compute the loss outside of the model and apply any loss function you like.
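As a rough sketch of what computing the loss outside the model can look like in PyTorch: the logits tensor below is a random stand-in for whatever your model returns, and the 3-class setup and class weights are made-up values for illustration only.

```python
import torch
import torch.nn.functional as F

# Stand-in for the logits returned by your model: 4 examples, 3 classes.
# (Assumption: in real code this would be something like `model(**batch).logits`.)
logits = torch.randn(4, 3, requires_grad=True)
labels = torch.tensor([0, 1, 2, 1])

# Any differentiable function of the logits can serve as the loss.
# Here: cross-entropy with hypothetical per-class weights.
class_weights = torch.tensor([1.0, 2.0, 3.0])
loss = F.cross_entropy(logits, labels, weight=class_weights)

# Gradients flow back through the logits (and, in real code, into the model).
loss.backward()
```

You would then step your optimizer as usual; the model itself never needs to see the labels.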
If your question was about the Trainer, you should define your own subclass that overrides the compute_loss method. There is an example in the documentation (scroll down a bit).
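For reference, a minimal sketch of such a subclass, loosely following the pattern in the documentation. The class name WeightedLossTrainer, the 3-label setup, and the class weights are assumptions for illustration. The num_items_in_batch argument is given a default because, as far as I know, only more recent transformers versions pass it.

```python
import torch
from torch import nn
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    """Trainer subclass that swaps in a weighted cross-entropy loss."""

    # Hypothetical per-class weights for a 3-label problem.
    class_weights = (1.0, 2.0, 3.0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Work on a shallow copy so the caller's batch dict is not mutated.
        inputs = dict(inputs)
        labels = inputs.pop("labels")

        # Without labels, the model only returns its outputs (including logits).
        outputs = model(**inputs)
        logits = outputs.logits

        weight = torch.tensor(self.class_weights, device=logits.device, dtype=logits.dtype)
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1), weight=weight
        )
        return (loss, outputs) if return_outputs else loss
```

You would construct it exactly like a regular Trainer (same model, args, and datasets); only the loss computation changes.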