So I am working with NLLB, and I've noticed it tends to repeat the same tokens over and over consecutively. The usual fix is to pass no_repeat_ngram_size, but that bans n-grams across the entire output, and that's not the issue here: names, for instance, should probably be allowed to repeat.
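To illustrate why the global check is too strict, here is a toy sketch of that style of n-gram banning (my own restatement in plain Python, not the actual HF implementation; the helper name is made up). A name that legitimately appears twice gets blocked exactly like a degenerate loop would:

```python
# Toy version of a *global* no-repeat-ngram check: ban any next token
# that would recreate an n-gram seen anywhere earlier in the output.
def banned_next_tokens(tokens, n):
    """Return the set of tokens that would complete an already-seen n-gram."""
    if len(tokens) < n - 1:
        return set()
    prefix = tuple(tokens[-(n - 1):])  # the n-1 most recent tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        # if an earlier position starts with the same n-1 tokens,
        # the token that followed it there is banned now
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned

# "Maria met John . Maria" -- the model wants to say "Maria met" again,
# which is perfectly fine text, but the 2-gram ("Maria", "met") already
# occurred once, so "met" gets globally banned:
history = ["Maria", "met", "John", ".", "Maria"]
print(banned_next_tokens(history, 2))  # {'met'}
```

That is the behavior I want to avoid: the ban should only kick in for consecutive repetition near the end of the sequence, not for any re-occurrence anywhere.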
Here is my solution: only look at a window of the most recent tokens, and if they contain a consecutive repeat, ban just the token that would continue it.
import torch
from transformers import LogitsProcessor

class StopRepeats(LogitsProcessor):
    # Stop spans of ngram_size or more tokens from repeating `count` times
    # inside the last `context` tokens.
    # For instance "abcabc" repeats twice, has an ngram_size of 3,
    # and fits in a context of 6.
    def __init__(self, count, ngram_size, context):
        self.count = count
        self.ngram_size = ngram_size
        self.context = context

    def __call__(self, input_ids, scores):
        # only consider the last `context` generated tokens
        if input_ids.size(1) > self.context:
            input_ids = input_ids[:, -self.context:]
        # try every span length from ngram_size up to half the context
        for step in range(self.ngram_size, self.context // 2 + 1):
            # the last count-1 spans of length `step`, newest first
            cuts = [input_ids[:, i:i + step] for i in range(input_ids.size(1) - step, 0, -step)]
            cuts = cuts[:self.count - 1]
            if len(cuts) != self.count - 1:
                continue
            # batch rows where all of those spans are identical
            matching = torch.ones(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
            for cut in cuts[1:]:
                matching &= (cut == cuts[0]).all(dim=1)
            # the token that would continue the cycle is the one `step`
            # positions back, i.e. the first token of the newest span
            scores[matching, cuts[0][matching, 0]] = float("-inf")
        return scores
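To sanity-check the tensor indexing, here is the same check restated in plain Python over a single token list (my restatement for readability, the function name is made up; the class above does this batched):

```python
def banned_continuations(tokens, count, ngram_size, context):
    """Return tokens that would complete `count` consecutive repeats of a span."""
    tokens = tokens[-context:]  # only look at the recent window
    banned = set()
    for step in range(ngram_size, context // 2 + 1):
        # the last count-1 spans of length `step`, newest first
        cuts = [tokens[i:i + step] for i in range(len(tokens) - step, 0, -step)]
        cuts = cuts[:count - 1]
        if len(cuts) != count - 1:
            continue
        if all(cut == cuts[0] for cut in cuts[1:]):
            # the next token continues the cycle iff it equals the token
            # `step` positions back: the first token of the newest span
            banned.add(cuts[0][0])
    return banned

# "abcab" + "c" would complete "abcabc" (two repeats of "abc"):
print(banned_continuations(list("abcab"), 2, 3, 6))  # {'c'}
# "ababa" + "b" would complete "ababab" (three repeats of "ab"):
print(banned_continuations(list("ababa"), 3, 2, 6))  # {'b'}
# nothing periodic, nothing banned:
print(banned_continuations(list("abcde"), 3, 2, 6))  # set()
```

To actually use the class, you'd pass an instance to generate via the logits_processor argument, e.g. model.generate(..., logits_processor=LogitsProcessorList([StopRepeats(count=3, ngram_size=1, context=64)])).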
I thought about opening a pull request to get this into transformers so other people can use it, but I don't know if that's the kind of thing you PR.