Calculating perplexity from hidden_states

Hi all, I am trying to run Ray Tune for my masked language model, and I want to find the hyperparameters that minimize the model's perplexity. I am not able to figure out how to calculate perplexity from the model's hidden_states, which are returned as EvalPrediction.predictions. Any help will be greatly appreciated. Thank you!

The following code snippet shows the training setup.

from transformers import AutoModelForMaskedLM, Trainer

model_checkpoint = "distilroberta-base"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to('cuda')

def compute_custom_metric(eval_pred):
    # prints (3387, 32, 50265), i.e. (num_eval_examples, max_output_len, vocab_size)
    print(eval_pred.predictions.shape)
    # prints (3387, 32), i.e. (num_eval_examples, max_output_len)
    print(eval_pred.label_ids.shape)
    return {'custom_metric': 0}

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train,
    eval_dataset = validation,
    tokenizer = tokenizer,
    data_collator = data_collator,
    compute_metrics = compute_custom_metric
)

trainer.evaluate()

After extensive searching, I finally found the solution. The function below calculates perplexity after every epoch. FYI, I have also added my training arguments.

import math
import torch
import torch.nn.functional as F

def compute_custom_metric(pred):
    logits = torch.from_numpy(pred.predictions)
    labels = torch.from_numpy(pred.label_ids)
    # cross_entropy skips label -100 (its default ignore_index), i.e. the non-masked tokens,
    # so the loss is averaged over the masked positions only
    loss = F.cross_entropy(logits.view(-1, tokenizer.vocab_size), labels.view(-1))
    return {'perplexity': math.exp(loss.item()), 'calculated_loss': loss.item()}

training_args = TrainingArguments(
    output_dir='./some_results',
    evaluation_strategy = "epoch",
    num_train_epochs=3,
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./some_logs',
    logging_steps=logging_steps,
    seed=seed,
    fp16=True,
    eval_accumulation_steps=50,
)
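Since the original goal was to run Ray Tune against this metric, here is a rough sketch of how it could be plugged into Trainer.hyperparameter_search. The hp_space is left at the backend's default, and model_init plus the objective lambda are my own assumptions, not something from the post above:

def model_init():
    # hyperparameter_search re-instantiates the model for each trial, so pass model_init instead of model
    return AutoModelForMaskedLM.from_pretrained(model_checkpoint)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=train,
    eval_dataset=validation,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_custom_metric,
)

best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="ray",
    n_trials=10,
    # Trainer prefixes metric names with "eval_", so the perplexity key is "eval_perplexity"
    compute_objective=lambda metrics: metrics["eval_perplexity"],
)
print(best_run.hyperparameters)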

@qpazuzu
Could you please explain in more depth how F.cross_entropy(logits.view(-1, tokenizer.vocab_size), labels.view(-1)) leads to the perplexity?

As I understand it, this part is supposed to give you $\sum_i \log p(w_i \mid w_{<i})$. How does F.cross_entropy give that?
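As a minimal, self-contained check of what F.cross_entropy computes (toy numbers of my own, not from the thread): it returns the mean of -log softmax(logits) at the label indices over the non-ignored positions, so exponentiating that mean gives the perplexity over the masked tokens.

import torch
import torch.nn.functional as F

# Toy example: 4 token positions, vocab of 5; one position is ignored via -100
logits = torch.randn(4, 5)
labels = torch.tensor([2, 0, -100, 4])

# Manual mean negative log-likelihood over the non-ignored positions
valid = labels != -100
log_probs = F.log_softmax(logits, dim=-1)
manual_nll = -log_probs[torch.arange(len(labels))[valid], labels[valid]].mean()

# F.cross_entropy gives the same value (its ignore_index defaults to -100)
assert torch.allclose(F.cross_entropy(logits, labels), manual_nll)

perplexity = torch.exp(manual_nll)  # exp of the average -log p(token)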