How to calculate perplexity from the `generate` function?

Hi,
I am trying to calculate the perplexity from the generate function. I use beam search as the decoding strategy, and I would like to get the perplexity for each of the returned sequences, e.g. the third-ranked one, not only the first.

To calculate the perplexity, I first need to calculate the loss, but I didn’t find a way to extract the logits from the generate function with beam search. I found that the scores are the “Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.” According to this post: [Announcement] GenerationOutputs: Scores, Attentions and Hidden States now available as outputs to generate, the scores now correspond to all processed lm head logits plus the current beam_scores for each output token. So I am confused about how to extract the logits to calculate the loss, or how to calculate the perplexity directly from the generate function.
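For context, perplexity itself does not require the loss object: it is just the exponential of the average negative log-probability of the target tokens, which equals exp(cross-entropy). A minimal self-contained sketch with made-up logits (not tied to any model):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, seq_len = 10, 5
logits = torch.randn(seq_len, vocab)        # made-up logits for one sequence
targets = torch.randint(vocab, (seq_len,))  # made-up target token ids

# Per-token log-probabilities of the target tokens
log_probs = torch.log_softmax(logits, dim=-1)
token_lp = log_probs[torch.arange(seq_len), targets]

# Perplexity = exp(-mean token log-probability) = exp(cross-entropy loss)
ppl_manual = math.exp(-token_lp.mean().item())
ppl_loss = math.exp(F.cross_entropy(logits, targets).item())
```

So the whole problem reduces to getting correct per-token log-probabilities out of the beam search scores.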

any thoughts @patrickvonplaten ?

I have written a function that calculates the ppl for one generated sequence:


import math

import torch


def calculate_ppl(scores, sequence, rank):
    """
    calculate_ppl calculates the perplexity for one generated sequence.

    Args:
        scores (Tuple[Tensor]): generation scores from `generate`, one
            (batch_size * num_beams, vocab_size) tensor per step
        sequence (Tensor): generated token ids, shape (1, seq_len)
        rank (int): row index of the sequence; this equals the beam rank
            only when batch_size == 1

    Returns:
        float: ppl for the sequence
    """
    # NOTE: with beam search, each row of `score` holds the processed
    # log-softmax PLUS the beam's cumulative score, so the per-step max
    # is a cumulative score, not a per-token log-probability.
    log_probs = [torch.max(score[rank]).item() for score in scores]
    ppl = math.exp(-sum(log_probs) / (sequence.shape[1] - 1))
    return ppl

But I am not sure this is correct, because the ppl is extremely low in my case.
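A possible explanation: since each step's scores row is the log-softmax plus the beam's cumulative score, summing the per-step maxima over-counts earlier tokens; the per-token log-probabilities can instead be recovered by differencing successive chosen-token scores. Below is a pure-torch simulation of that idea. It assumes batch size 1 and no beam reordering between steps (real beam search does reorder, which is why newer transformers versions provide `model.compute_transition_scores(...)` to do this bookkeeping for you):

```python
import math

import torch

torch.manual_seed(0)
num_beams, vocab, steps = 3, 10, 4

# Simulate what beam search generate() exposes: at each step, a beam's
# scores row is log_softmax(logits) + the beam's cumulative score so far.
true_log_probs = []           # per-step log-softmax (ground truth)
scores = []                   # what generate() would return in `scores`
chosen = []                   # token each beam keeps at each step
cum = torch.zeros(num_beams)  # cumulative beam scores
for _ in range(steps):
    lp = torch.log_softmax(torch.randn(num_beams, vocab), dim=-1)
    true_log_probs.append(lp)
    scores.append(lp + cum.unsqueeze(-1))
    tok = lp.argmax(dim=-1)   # pretend each beam keeps its best token
    chosen.append(tok)
    cum = cum + lp.gather(-1, tok.unsqueeze(-1)).squeeze(-1)

# Recover per-token log-probabilities for one beam by differencing the
# cumulative chosen-token scores (valid here because no reordering occurs).
rank = 2
prev = 0.0
token_log_probs = []
for t in range(steps):
    s = scores[t][rank, chosen[t][rank]].item()
    token_log_probs.append(s - prev)
    prev = s

ppl = math.exp(-sum(token_log_probs) / steps)
```

With this recovery, each summand is a genuine per-token log-probability, so the resulting ppl should land in a plausible range instead of collapsing toward an extreme value.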