I'm studying the source code of beam search, and some implementation details confuse me. The relevant code follows.
- Question 1:
```python
# generation/beam_search.py:BeamSearchScorer.process
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
    batch_beam_idx = batch_idx * self.group_size + next_index
    # add to generated hypotheses if end of sentence
    if (eos_token_id is not None) and (next_token.item() in eos_token_id):
        # if beam_token does not belong to top num_beams tokens, it should not be added
        is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
        if is_beam_token_worse_than_top_num_beams:
            continue
        if beam_indices is not None:
            beam_index = beam_indices[batch_beam_idx]
            beam_index = beam_index + (batch_beam_idx,)
        else:
            beam_index = None
        beam_hyp.add(
            input_ids[batch_beam_idx].clone(),
            next_score.item(),
            beam_indices=beam_index,
        )
    else:
        # add next predicted token since it is not eos_token
        next_beam_scores[batch_idx, beam_idx] = next_score
        next_beam_tokens[batch_idx, beam_idx] = next_token
        next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
        beam_idx += 1

    # once the beam for next step is full, don't add more tokens to it.
    if beam_idx == self.group_size:
        break
```
I wonder why a beam token that does not belong to the top `num_beams` tokens should not be added. Consider beam size 4 with the current candidate sequences [A, B, C, D, E, F, G, H], ranked by score, where A and E reach the EOS token at the current step. This implementation adds A to `BeamHypotheses` but not E, and the next beams will be [B, C, D, F]. But when `length_penalty < 0`, the sequence F + [next tokens] cannot score better than E, yet E is discarded forever; moreover, the extended sequences of B, C, and D are not guaranteed to be better or worse than E. So why do we ignore sequence E?
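To make the scenario concrete, here is a minimal, self-contained sketch of the selection logic quoted above. The candidate names, EOS flags, and `group_size` value are my own invented example, not values from the library:

```python
# Hypothetical example: 2 * num_beams = 8 candidates, ranked best-first.
group_size = 4  # num_beams
candidates = ["A", "B", "C", "D", "E", "F", "G", "H"]
is_eos = {"A": True, "E": True}  # which candidates emitted EOS this step

finished_hyps = []  # what would be passed to BeamHypotheses.add
next_beams = []     # beams kept for the next step
for rank, seq in enumerate(candidates):
    if is_eos.get(seq, False):
        # same guard as the source: an EOS candidate ranked below the
        # top num_beams is silently dropped
        if rank >= group_size:
            continue
        finished_hyps.append(seq)
    else:
        next_beams.append(seq)
    # once the beam for the next step is full, stop
    if len(next_beams) == group_size:
        break

print(finished_hyps)  # ['A']  -> E (rank 4) is dropped
print(next_beams)     # ['B', 'C', 'D', 'F']
```

Running this shows exactly the asymmetry in the question: A is recorded as a finished hypothesis, while E, with the same EOS but rank >= `group_size`, is never considered again.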
- Question 2:
I think that when `self.length_penalty > 0`, we cannot conclude that beam search is done:
```python
# generation/beam_search.py
class BeamHypotheses:
    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len**self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
```
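To illustrate the concern numerically, here is a small, hypothetical calculation (all numbers are invented, not taken from the library). Since log-probabilities are negative, with `length_penalty > 0` the score `sum_logprobs / cur_len**length_penalty` can *rise* as the sequence grows longer, so the test `worst_score >= cur_score` can pass even though a longer continuation of the current best beam could still beat the worst finished hypothesis:

```python
# Hypothetical numbers, purely for illustration.
length_penalty = 2.0
worst_score = -0.5        # score of the worst finished hypothesis
best_sum_logprobs = -3.0  # raw log-prob sum of the best running beam
cur_len = 2

cur_score = best_sum_logprobs / cur_len**length_penalty  # -3.0 / 4 = -0.75
print(worst_score >= cur_score)  # True -> is_done would report finished

# But if that beam survived two more steps at only a small log-prob cost,
# the length penalty would lift its score above worst_score:
future_len = 4
future_sum = best_sum_logprobs - 0.1  # small additional cost
future_score = future_sum / future_len**length_penalty  # -3.1 / 16 ~= -0.194
print(future_score > worst_score)  # True -> the dropped beam would have won
```

Under these assumed numbers, `is_done` returns `True` at length 2 even though the same beam, extended to length 4, would have outscored the worst finished hypothesis.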