Custom loss: does this word exist?


I am trying to write a custom loss that checks whether certain words appear in the generated output.
This is causing me two problems:

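  • I cannot be sure that each word I am looking for is represented by a single token; the tokenizer may split it into several subword tokens
  • I cannot simply decode the output and search for the words, because decoding (argmax followed by detokenization) is not differentiable, so the gradient tape cannot propagate gradients through it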
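To make both problems concrete, here is a minimal pure-Python sketch of the check that decoding effectively performs: take the argmax token at each position, then count contiguous runs that match a word's token ids. The helper names, vocabulary size and token ids are hypothetical. Because argmax is piecewise constant, this count has zero (or no) gradient with respect to the logits, which is why it cannot be used to train the model.

```python
from typing import List, Sequence


def count_subsequence(token_ids: Sequence[int], word_ids: Sequence[int]) -> int:
    """Count contiguous (possibly overlapping) occurrences of word_ids in token_ids."""
    k = len(word_ids)
    if k == 0:
        return 0
    return sum(
        1
        for start in range(len(token_ids) - k + 1)
        if list(token_ids[start:start + k]) == list(word_ids)
    )


def hard_count(logits: List[List[float]], word_ids: Sequence[int]) -> int:
    """Greedy-decode each position with argmax, then count the word's token run.

    argmax is piecewise constant, so this count gives no useful gradient
    with respect to the logits.
    """
    ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return count_subsequence(ids, word_ids)


# Hypothetical example: vocabulary of 5 tokens, sequence of 4 positions,
# and a target word that the tokenizer splits into the ids [2, 3].
logits = [
    [0.1, 0.2, 3.0, 0.0, 0.0],  # argmax -> 2
    [0.0, 0.1, 0.2, 2.5, 0.0],  # argmax -> 3
    [1.0, 0.0, 0.0, 0.0, 0.5],  # argmax -> 0
    [0.0, 0.0, 2.0, 0.1, 0.0],  # argmax -> 2
]
print(hard_count(logits, [2, 3]))  # 1
```

Matching on the word's id sequence rather than on a single id is what handles words that are split into several tokens.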

What would be the best way to count how many tokens from a given set appear in the model output while keeping the gradient intact?
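One common workaround, sketched here under stated assumptions rather than as an established recipe, is to replace the hard count with a soft, expected count computed from the softmax probabilities: for each start position, multiply the probabilities of the word's tokens at consecutive positions, then sum over start positions. If each position's token were drawn independently from its distribution, this would be the expected number of occurrences; for an autoregressive model the positions are not independent, so treat it as an approximation. The helper names and token ids are hypothetical, and the code is plain Python for readability. Written with TensorFlow ops such as `tf.nn.softmax`, `tf.gather` and `tf.reduce_prod`, the same computation stays differentiable with respect to the logits.

```python
import math
from typing import List, Sequence


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one position's logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def soft_count(logits: List[List[float]], word_ids: Sequence[int]) -> float:
    """Expected number of occurrences of the token run word_ids.

    Assumes each position's token is drawn independently from its softmax
    distribution (an approximation for autoregressive models). Every step
    is a smooth function of the logits, so a tensor-library version of it
    can be backpropagated through.
    """
    k = len(word_ids)
    if k == 0:
        return 0.0
    probs = [softmax(row) for row in logits]
    total = 0.0
    for start in range(len(probs) - k + 1):
        # Probability that the whole word starts at this position.
        p = 1.0
        for offset, tok in enumerate(word_ids):
            p *= probs[start + offset][tok]
        total += p
    return total


# Hypothetical setup: 4 positions, vocabulary of 5, word = [2, 3].
logits = [
    [0.1, 0.2, 3.0, 0.0, 0.0],
    [0.0, 0.1, 0.2, 2.5, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.5],
    [0.0, 0.0, 2.0, 0.1, 0.0],
]
print(soft_count(logits, [2, 3]))  # a value between 0 and 3

# Sharper logits push the soft count toward the hard argmax count (1 here).
sharp = [[100 * x for x in row] for row in logits]
print(round(soft_count(sharp, [2, 3]), 6))  # 1.0
```

A loss could then be, for example, the negative soft count summed over the target words to reward their presence, or the positive soft count to penalize it. As the logits become more peaked, the probabilities approach one-hot vectors and the soft count approaches the hard argmax count.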