Calculating VLM and VLA Loss

I am trying to calculate the perplexity of the OpenVLA model and its Prismatic VLM backbone.

I was able to get a reasonable number by using compute_transition_scores() to get the per-token log probabilities.

import numpy as np
import torch

scores = vla.compute_transition_scores(
    # sequences=discrete_action_tensor,
    sequences=generated_ids[:, -7:],
    scores=generated_scores,
    normalize_logits=True,
)

print(scores)
print(scores.shape)

# Count generated tokens (their log probs are negative) and apply the length penalty
output_length = np.sum(scores.cpu().numpy() < 0, axis=1)
length_penalty = vla.generation_config.length_penalty
reconstructed_scores = scores.cpu().sum(axis=1) / (output_length**length_penalty)
print(reconstructed_scores)

# Perplexity is the exponential of the negative average log probability
perplexity = torch.exp(-reconstructed_scores).item()
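To sanity-check the arithmetic above, here is a minimal sketch with made-up per-token log probabilities (the kind of tensor compute_transition_scores returns with normalize_logits=True); the values are illustrative, not from the actual model:

```python
import torch

# Toy per-token log probabilities for one generated sequence (hypothetical values)
log_probs = torch.tensor([[-0.2, -0.5, -0.1, -0.4]])

# Average log-likelihood per token, then perplexity = exp(-mean log prob)
mean_log_prob = log_probs.mean()
perplexity = torch.exp(-mean_log_prob)
```

With these toy values the mean log prob is -0.3, giving a perplexity of about 1.35.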

I expected the loss from the model to give the same value, so I tried to calculate it as follows:

batched_pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values.to(device=vla1.device, dtype=model_dtype)

input_ids_for_loss = generated_ids[:, -7:-1]  # generated tokens except the last
label_ids_for_loss = generated_ids[:, -6:]    # the same tokens shifted right by one

output = vla1(
    input_ids=input_ids_for_loss,
    attention_mask=torch.ones_like(input_ids_for_loss),
    pixel_values=batched_pixel_values,
    labels=label_ids_for_loss,
    return_dict=True
)

loss = output.loss.item()
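For reference, the equivalence I am expecting is that cross-entropy loss is the mean negative log-likelihood of the labels, so exp(loss) should match the perplexity from method 1. A minimal sketch with random toy logits (shapes and values are hypothetical):

```python
import torch
import torch.nn.functional as F

# Toy logits over a 5-token vocabulary at 3 positions, plus the target tokens
torch.manual_seed(0)
logits = torch.randn(3, 5)
labels = torch.tensor([1, 0, 3])

# Cross-entropy loss = mean negative log-likelihood of the labels
loss = F.cross_entropy(logits, labels)

# The same quantity computed explicitly from log-softmax probabilities
log_probs = F.log_softmax(logits, dim=-1)
mean_log_prob = log_probs[torch.arange(3), labels].mean()

# exp(loss) and exp(-mean_log_prob) should agree: both are the perplexity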

The values are very different, though: the average log prob from method 1 is -0.3929 for my example, while the loss is around 4.

Clearly, the loss seems wrong, but I am not sure what the problem is, since I am using the next token as the label for the output sequence. I wondered if there's an off-by-one error, but I did handle the BOS and EOS tokens in preprocessing. Please help :face_holding_back_tears:
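On the off-by-one question: Hugging Face causal-LM forward passes typically shift the labels internally (logits at position t are scored against the token at position t+1), so labels are usually passed as the same ids as input_ids, not pre-shifted. A minimal sketch of that internal shift, with toy shapes and values:

```python
import torch
import torch.nn.functional as F

# Toy logits for one sequence of 4 tokens over a 5-token vocabulary
vocab, seq = 5, 4
torch.manual_seed(0)
logits = torch.randn(1, seq, vocab)
labels = torch.tensor([[2, 4, 1, 3]])  # same ids as input_ids, not pre-shifted

# The shift most HF causal LMs apply inside forward(); passing labels that
# are already shifted by one would effectively shift them twice.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))
```

This sketch only illustrates the shifting convention; the OpenVLA forward pass may differ in details.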

Thank you guys in advance.
