Stopping `model.generate()` based on custom token

Hello everyone, I’ve managed to train a Hugging Face model that generates coherent sequences based on my training data, and I am using `generate()` to create new sequences. This has worked well so far; however, I need to stop generation based on the count of a particular token that marks the start of a subsequence in my domain. Is there a way to leverage the `generate()` method to do this? That is, rather than generating up to a fixed length, generate until n occurrences of a particular token have been produced.


I found that the best way to do this is to call the model directly with the necessary inputs rather than using the `generate()` method, and to build logic around it that counts a particular token in the resulting sequence and stops once the target count is reached.
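A minimal sketch of that loop, assuming greedy decoding. To keep it independent of any particular model, the model call is abstracted as a hypothetical `step_fn(ids) -> next_token_id`. With a Hugging Face causal LM, one possible `step_fn` is `lambda ids: model(torch.tensor([ids])).logits[0, -1].argmax().item()`, which re-runs the whole sequence each step (no KV cache), so it is simple but slow for long outputs. The function name, `target_token_id`, `max_new_tokens`, and `eos_token_id` handling are illustrative choices, not part of any library API. Only generated tokens are counted, not ones already in the prompt.

```python
from typing import Callable, List, Optional


def generate_until_token_count(
    step_fn: Callable[[List[int]], int],  # hypothetical: current ids -> next token id
    prompt_ids: List[int],
    target_token_id: int,
    n: int,
    max_new_tokens: int = 512,
    eos_token_id: Optional[int] = None,
) -> List[int]:
    """Extend prompt_ids until target_token_id has been generated n times.

    Stops early on eos_token_id (if given) or after max_new_tokens steps,
    so a model that never emits the target token cannot loop forever.
    """
    ids = list(prompt_ids)
    count = 0  # occurrences of target_token_id among *generated* tokens only
    for _ in range(max_new_tokens):
        next_id = step_fn(ids)
        ids.append(next_id)
        if next_id == target_token_id:
            count += 1
            if count >= n:
                break
        if eos_token_id is not None and next_id == eos_token_id:
            break
    return ids


if __name__ == "__main__":
    # Fake "model" that emits a fixed token stream, just to show the stopping logic.
    fake_tokens = iter([5, 7, 3, 7, 9])
    print(generate_until_token_count(lambda ids: next(fake_tokens), [0], target_token_id=7, n=2))
    # prints [0, 5, 7, 3, 7]
```

For real use, passing `past_key_values` back into the model between steps avoids recomputing the full sequence, and sampling can replace `argmax` if greedy output is not what you want.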


Can you share your code snippet for doing this? I want to implement a similar custom generate function but can’t parse through the entire codebase in a short time.
