I’m studying the source code of beam search, and some implementation details confuse me. The code below is from transformers==4.26.1:
- Question 1:
```python
# generation/beam_search.py: BeamSearchScorer.process
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
    batch_beam_idx = batch_idx * self.group_size + next_index
    # add to generated hypotheses if end of sentence
    if (eos_token_id is not None) and (next_token.item() in eos_token_id):
        # if beam_token does not belong to top num_beams tokens, it should not be added
        is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
        if is_beam_token_worse_than_top_num_beams:
            continue
        if beam_indices is not None:
            beam_index = beam_indices[batch_beam_idx]
            beam_index = beam_index + (batch_beam_idx,)
        else:
            beam_index = None
        beam_hyp.add(
            input_ids[batch_beam_idx].clone(),
            next_score.item(),
            beam_indices=beam_index,
        )
    else:
        # add next predicted token since it is not eos_token
        next_beam_scores[batch_idx, beam_idx] = next_score
        next_beam_tokens[batch_idx, beam_idx] = next_token
        next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
        beam_idx += 1

    # once the beam for next step is full, don't add more tokens to it.
    if beam_idx == self.group_size:
        break
```
I wonder why, "if beam_token does not belong to top num_beams tokens, it should not be added"? Consider a beam size of 4 with candidate sequences [A, B, C, D, E, F, G, H], where A and E reach the EOS token at the current step. This implementation adds A to BeamHypotheses but not E, and the next beams will be [B, C, D, F]. But when length_penalty < 0, the sequence F + [next tokens] can never score better than E, yet E is permanently discarded; moreover, the extended sequences of B, C, and D are not guaranteed to be better or worse than E. So why do we ignore sequence E?
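To make the concern concrete, here is a toy calculation (the log-probability sums are hypothetical; `hyp_score` mirrors the `sum_logprobs / len**length_penalty` normalization used in `BeamHypotheses.add`):

```python
def hyp_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # same normalization as BeamHypotheses.add: log-prob sum / length**penalty
    return sum_logprobs / length ** length_penalty

length_penalty = -1.0  # any length_penalty <= 0 behaves the same way here
cur_len = 5
e_sum = -6.0           # hypothetical log-prob sum of E, which ends in EOS
f_sum = -7.0           # F is ranked below E, so its sum is lower

e_score = hyp_score(e_sum, cur_len, length_penalty)
# extending F adds a log-prob <= 0, so its sum can only decrease,
# and with length_penalty <= 0 the normalization cannot compensate
f_extended_score = hyp_score(f_sum - 0.5, cur_len + 1, length_penalty)
assert f_extended_score < e_score  # F + [next tokens] never overtakes E
```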
- Question 2:
I think when self.length_penalty > 0, we cannot conclude that beam search is done: best_sum_logprobs can only decrease as the sequence grows, but cur_len**length_penalty grows as well, so a longer hypothesis may still reach a score above worst_score even after is_done returns True.
```python
class BeamHypotheses:
    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len**self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
```
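A numeric sketch of this concern (all scores hypothetical): with length_penalty > 0, the denominator cur_len**length_penalty keeps growing, so a future, longer hypothesis can still score above worst_score even after is_done has returned True:

```python
def hyp_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # same normalization as BeamHypotheses: log-prob sum / length**penalty
    return sum_logprobs / length ** length_penalty

length_penalty = 2.0
worst_score = -0.3         # hypothetical worst finished hypothesis in beam_hyp
best_sum_logprobs = -10.0  # best running (unfinished) hypothesis
cur_len = 5

# is_done's estimate of the best achievable score at the current length
cur_score = hyp_score(best_sum_logprobs, cur_len, length_penalty)  # -0.4
assert worst_score >= cur_score  # is_done would return True here

# ...but one more step with a near-certain token (log-prob -0.1) yields a
# higher score, because cur_len**2 grows faster than the log-prob sum shrinks
future_score = hyp_score(best_sum_logprobs - 0.1, cur_len + 1, length_penalty)
assert future_score > worst_score  # a better hypothesis was still reachable
```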