About the ProphetNet model's n-gram loss calculation

    def _compute_loss(self, logits, labels, ignore_index=-100):
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

I have a question about how the n-gram loss is calculated.

Following the _compute_loss function, expend_targets has shape [n_gram, batch, seq_len]
and the logits have shape [n_gram, batch, seq_len, vocab_size].

But the reference code copies the same labels into expend_targets along the first dimension (n_gram).
I don't understand why the logits for every n-gram have the same target (label) ids.
I think there should be some kind of shifting step for predicting future n-grams.
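
To make the confusion concrete, here is a minimal sketch using plain Python lists instead of tensors. It is not the library code: the names copy_targets and shifted_targets are made up for illustration, and the shifted version is only a guess at what I expected to see, not something taken from the reference implementation.

```python
IGNORE_INDEX = -100


def copy_targets(labels, ngram, disable_ngram_loss=False, ignore_index=IGNORE_INDEX):
    """Mirror the quoted loop: every n-gram slot receives the same labels."""
    batch, seq_len = len(labels), len(labels[0])
    # Shape [ngram, batch, seq_len], pre-filled with ignore_index.
    targets = [[[ignore_index] * seq_len for _ in range(batch)] for _ in range(ngram)]
    for i in range(ngram):
        if i > 0 and disable_ngram_loss:
            break  # only slot 0 is filled; the rest stay at ignore_index
        targets[i] = [list(row) for row in labels]
    return targets


def shifted_targets(labels, ngram, ignore_index=IGNORE_INDEX):
    """Hypothetical alternative (my assumption): slot i is shifted left by i."""
    return [
        [row[i:] + [ignore_index] * min(i, len(row)) for row in labels]
        for i in range(ngram)
    ]


labels = [[11, 12, 13, 14]]  # batch of 1, seq_len of 4
copied = copy_targets(labels, ngram=2)
shifted = shifted_targets(labels, ngram=2)

print(copied)   # both n-gram slots hold identical label ids
print(shifted)  # slot 1 is offset by one position and padded with -100
```

The first function is what the snippet appears to do, and the second is the kind of offset I assumed would be needed on the target side.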

Actually, this might be a misunderstanding on my part.
Could anyone explain this to me?