Nice inference trick

So I'm working with NLLB, and what I noticed is that it tends to repeat the same tokens over and over, back to back. What people usually do here is pass no_repeat_ngram_size, but that checks for n-grams across the entire text, and that's not the issue here. Names, for instance, should probably be allowed to repeat.
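To make the difference concrete, here's a rough plain-Python sketch (no model involved). The two helper names, `banned_by_global_ngram` and `is_consecutive_loop`, are made up for illustration and are not part of transformers. The first one imitates the idea of the whole-text n-gram check, the second only asks whether the text ends in back-to-back copies of the same chunk.

```python
def banned_by_global_ngram(tokens, n):
    """Tokens that would recreate an n-gram seen anywhere earlier in `tokens`."""
    if n < 1 or len(tokens) < n:
        return set()
    prefix = tokens[len(tokens) - (n - 1):]  # the last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n - 1] == prefix:
            banned.add(tokens[i + n - 1])
    return banned


def is_consecutive_loop(tokens, period, copies):
    """True if `tokens` ends with `copies` back-to-back copies of one `period`-token chunk."""
    if period < 1 or copies < 1:
        return False
    need = period * copies
    if len(tokens) < need:
        return False
    tail = tokens[len(tokens) - need:]
    return all(tail[i:i + period] == tail[:period] for i in range(0, need, period))


# A name that legitimately comes back later in the sentence.
name_tokens = ["Anna", "Lee", "met", "Bob", ".", "Anna"]
print(banned_by_global_ngram(name_tokens, 2))  # {'Lee'}
print(is_consecutive_loop(name_tokens, 1, 2))  # False

# A real degenerate loop.
loop_tokens = ["ok", "the", "the", "the"]
print(is_consecutive_loop(loop_tokens, 1, 3))  # True
```

So the whole-text check would stop "Anna Lee" from ever appearing a second time, even though nothing is looping, while the consecutive check only fires on the actual loop.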

Here is my solution:
just check the last few tokens for a back-to-back repeat and act only on that.

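    import torch
    from transformers import LogitsProcessor

    class StopRepeats(LogitsProcessor):
        # Stop back-to-back repeats of chunks that are ngram_size tokens or longer,
        # looking only at the last `context` tokens.
        # For instance "abcabc" is a chunk of size 3 repeated twice and fits in a context of 6.
        # count is the number of consecutive copies to block (count=3 blocks a third copy).
        def __init__(self, count, ngram_size, context):
            self.count = count
            self.ngram_size = ngram_size
            self.context = context

        def __call__(self, input_ids, scores):
            # Only look at the most recent `context` tokens.
            if input_ids.size(1) > self.context:
                input_ids = input_ids[:, -self.context:]

            for step in range(self.ngram_size, self.context // 2 + 1):
                # Cut the tail into chunks of `step` tokens, newest first.
                # The stop is -1 so that a chunk starting at index 0 is included.
                cuts = [input_ids[:, i:i + step] for i in range(input_ids.size(1) - step, -1, -step)]
                cuts = cuts[:self.count - 1]
                if len(cuts) != self.count - 1 or not cuts:
                    continue

                # Rows where the last count-1 chunks are all identical.
                matching = torch.ones(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
                for cut in cuts[1:]:
                    matching &= (cut == cuts[0]).all(dim=1)

                # Ban the token that would start one more copy of the chunk.
                scores[matching, cuts[0][matching, 0]] = float("-inf")

            return scores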
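If you want to see what a check like this bans without loading a model, here's a plain-Python sketch for a single sequence. It's a simplified trace of the idea (compare the last `count - 1` chunks, then ban the token that would start another copy), not the processor itself. The function name `tokens_to_ban` is made up for illustration, and it assumes `ngram_size >= 1`.

```python
def tokens_to_ban(tokens, count, ngram_size, context):
    """Single-sequence sketch: which next tokens would start yet another copy of a loop?"""
    window = tokens[-context:] if context > 0 else []
    banned = set()
    for step in range(ngram_size, context // 2 + 1):
        # Walk backwards from the end in chunks of `step` tokens, newest first.
        starts = range(len(window) - step, -1, -step)
        chunks = [window[i:i + step] for i in starts][:count - 1]
        if len(chunks) != count - 1 or not chunks:
            continue  # not enough history for this chunk size
        if all(chunk == chunks[0] for chunk in chunks):
            # The first token of the chunk is what would begin one more copy.
            banned.add(chunks[0][0])
    return banned


# "abc" twice in a row: a third copy would start with "a".
print(tokens_to_ban(list("xabcabc"), count=3, ngram_size=1, context=6))  # {'a'}

# Same token twice in a row: block a third one.
print(tokens_to_ban(["hi", "the", "the"], count=3, ngram_size=1, context=6))  # {'the'}

# A name that shows up again later is left alone.
print(tokens_to_ban(["Anna", "Lee", "met", "Bob", ".", "Anna"], count=3, ngram_size=1, context=6))  # set()
```

For actual generation, custom processors go into `generate` through its `logits_processor` argument, wrapped in a `LogitsProcessorList`. The values for count, ngram_size and context are something you'd have to tune for your own model and data.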

I thought maybe I'd open a pull request for this on HF so people can use it, but idk if that's the type of thing you open a pull request for.