I am trying to calculate the perplexity of the OpenVLA model and its Prismatic VLM backbone.
I was able to get a reasonable number by using compute_transition_scores() to get the per-token log probabilities:
import numpy as np
import torch

# Per-token log probs for the 7 generated action tokens
scores = vla.compute_transition_scores(
    # sequences=discrete_action_tensor,
    sequences=generated_ids[:, -7:],
    scores=generated_scores,
    normalize_logits=True,
)
print(scores)
print(scores.shape)

# Length-normalized sum of log probs, then perplexity
output_length = np.sum(scores.cpu().numpy() < 0, axis=1)
length_penalty = vla.generation_config.length_penalty
reconstructed_scores = scores.cpu().sum(axis=1) / (output_length ** length_penalty)
print(reconstructed_scores)
perplexity = torch.exp(-reconstructed_scores).item()
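
As a sanity check on method 1 (a minimal sketch, assuming batch size 1 and that generated_scores is the per-step logits tuple returned by generate(..., output_scores=True, return_dict_in_generate=True)), the same per-token log probs can be recovered by hand:

import torch

# Sketch: recompute the per-token log probs directly from the per-step logits.
# Assumes generated_scores has one (batch, vocab) entry per generated token.
manual_logprobs = torch.stack([
    torch.log_softmax(step_logits, dim=-1)[0, token_id]
    for step_logits, token_id in zip(generated_scores, generated_ids[0, -7:])
])
print(manual_logprobs)  # should match scores[0] from compute_transition_scores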
I expected the loss from the model to be the same value (since the loss is the negative mean log probability), so I tried to calculate it as follows:
# Teacher-forced forward pass over the same action tokens
batched_pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values.to(device=vla.device, dtype=model_dtype)
input_ids_for_loss = generated_ids[:, -7:-1]  # action tokens 1..6 as inputs
label_ids_for_loss = generated_ids[:, -6:]    # action tokens 2..7 as (pre-shifted) labels
output = vla(
    input_ids=input_ids_for_loss,
    attention_mask=torch.ones_like(input_ids_for_loss),
    pixel_values=batched_pixel_values,
    labels=label_ids_for_loss,
    return_dict=True,
)
loss = output.loss.item()
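
One thing I would like to rule out (I have not verified this for OpenVLA specifically): standard Hugging Face causal-LM forwards already shift the labels internally, comparing logits[:, :-1] against labels[:, 1:], so passing pre-shifted labels would shift twice. Here is a minimal sketch that makes the alignment explicit, assuming the image patch embeddings are prepended so the text logits occupy the last input_ids.shape[1] positions:

import torch
import torch.nn.functional as F

# Sketch: compute the loss by hand so the input/label shift is explicit.
# Assumption: vision patch embeddings are prepended to the text tokens, so the
# logits for the text positions are the last action_ids.shape[1] entries.
action_ids = generated_ids[:, -7:]  # all 7 action tokens, unshifted
out = vla(
    input_ids=action_ids,
    attention_mask=torch.ones_like(action_ids),
    pixel_values=batched_pixel_values,
    return_dict=True,
)
text_logits = out.logits[:, -action_ids.shape[1]:, :]
manual_loss = F.cross_entropy(
    text_logits[:, :-1, :].reshape(-1, text_logits.size(-1)),  # predicts each next token
    action_ids[:, 1:].reshape(-1),                             # actual next tokens 2..7
)
print(manual_loss.item())

Note that this forward pass (like my original one) conditions only on the image and the action tokens, not on the full language prompt used during generation, so the distributions may still differ from method 1 for that reason alone.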
The values are very different, though: the average log prob from method 1 is -0.3929 for my example (perplexity ≈ exp(0.3929) ≈ 1.48), while the loss is around 4 (perplexity ≈ exp(4) ≈ 54.6).
Clearly, the loss seems wrong, but I am not sure what the problem is, since I am using the next token as the label for the output sequence. I wondered whether there is an off-by-one error, but I did handle the BOS and EOS tokens in preprocessing. Please help!
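
To help rule out the off-by-one, here is how I would inspect the exact tokens each method sees (assuming processor.tokenizer exposes the underlying text tokenizer):

tok = processor.tokenizer  # assumption: the processor wraps the text tokenizer here
print(tok.convert_ids_to_tokens(generated_ids[0, -7:].tolist()))   # tokens scored by method 1
print(tok.convert_ids_to_tokens(input_ids_for_loss[0].tolist()))   # inputs for method 2
print(tok.convert_ids_to_tokens(label_ids_for_loss[0].tolist()))   # labels for method 2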
Thank you guys in advance.