I am trying to compute Teacher-Forced Accuracy (TFA) for Hugging Face models, ensuring the following:
- EOS Token Handling: The model should be rewarded for predicting the first EOS token.
- Ignoring Padding: Any padding tokens (beyond the first EOS) should be ignored during accuracy calculation.
- Right-Shifted Input: The inputs should be shifted correctly for teacher-forced evaluation.
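Concretely, for a padded toy batch this is the behaviour I am after (a hand-written expectation with made-up token IDs, not output from the code below):

```python
import torch

# Made-up token IDs for illustration; 2 plays the role of the EOS/pad token.
input_ids = torch.tensor([
    [5, 9, 7, 2, 2, 2],    # first EOS at index 3, the trailing 2s are padding
    [11, 4, 2, 2, 2, 2],   # first EOS at index 2
])
# Positions that should count towards the accuracy:
expected_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],    # tokens 0-3, i.e. up to and including the first EOS
    [1, 1, 1, 0, 0, 0],    # tokens 0-2
], dtype=torch.bool)
```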
Here’s the full code I wrote:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def compute_tfa(model, tokenizer, input_texts):
    """
    Computes Teacher-Forced Accuracy (TFA), rewarding the model for correctly predicting
    the first EOS token while ignoring predictions for padding tokens.

    Parameters:
        model: The language model (Hugging Face CausalLM).
        tokenizer: The tokenizer corresponding to the model.
        input_texts: List of input texts to compute TFA.

    Returns:
        TFA score as a float.
    """
    # Tokenize input texts
    tokenizer.pad_token = tokenizer.eos_token  # Use EOS as the pad token
    inputs = tokenizer(input_texts, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']

    # Create right-shifted input by adding the EOS token at the beginning
    eos_token_id = tokenizer.eos_token_id
    right_shifted_input_ids = torch.cat([
        torch.full((input_ids.shape[0], 1), eos_token_id, dtype=torch.long),  # Add EOS token
        input_ids[:, :-1]
    ], dim=1)

    # Perform a forward pass with the right-shifted inputs
    with torch.no_grad():
        outputs = model(input_ids=right_shifted_input_ids)
        logits = outputs.logits  # Shape: (batch_size, sequence_length, vocab_size)

    # Compute predictions
    predicted_token_ids = torch.argmax(logits, dim=-1)  # Shape: (batch_size, sequence_length)

    # Find the first EOS position in each sequence
    eos_positions = (input_ids == eos_token_id).int().argmax(dim=1)  # Shape: (batch_size,)

    # Mask to ignore tokens after the first EOS
    sequence_lengths = input_ids.size(1)
    mask = torch.arange(sequence_lengths).unsqueeze(0).to(input_ids.device)
    mask = mask < eos_positions.unsqueeze(1)

    # Include the first EOS token in the mask
    mask.scatter_(1, eos_positions.unsqueeze(1), 1)

    # Apply the mask to filter predictions and labels
    filtered_predictions = predicted_token_ids[mask]
    filtered_labels = input_ids[mask]

    # Compute accuracy
    correct_predictions = (filtered_predictions == filtered_labels).float()
    accuracy = correct_predictions.mean().item()
    return accuracy


def main():
    # Define models and their URLs
    models_and_urls = {
        "google/gemma-2-2b": "https://huggingface.co/google/gemma-2-2b",
        "meta-llama/Llama-3.1-8B": "https://huggingface.co/meta-llama/Llama-3.1-8B",
        "gpt2": "https://huggingface.co/gpt2"
    }

    # Define input texts
    input_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial Intelligence is transforming the world of science."
    ]

    # Test each model
    for model_name, model_url in models_and_urls.items():
        print(f"Testing model: {model_name} ({model_url})")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # Compute TFA
        tfa_score = compute_tfa(model, tokenizer, input_texts)
        print(f"Teacher-Forced Accuracy (TFA) for {model_name}: {tfa_score:.4f}\n")


if __name__ == "__main__":
    main()
```
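One thing I am unsure about, related to the Generalization question below: the masking assumes every padded sequence actually contains an EOS token. A quick check of what a given tokenizer emits for plain text looks like this (gpt2 used here only because it is small to download):

```python
from transformers import AutoTokenizer

# Quick check: does this tokenizer define a pad token and append EOS to plain text?
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("The quick brown fox jumps over the lazy dog.")["input_ids"]

print("eos_token_id:", tokenizer.eos_token_id)   # 50256 for gpt2
print("pad_token_id:", tokenizer.pad_token_id)   # None for gpt2 until one is set explicitly
print("ends with EOS:", ids[-1] == tokenizer.eos_token_id)  # False: gpt2 does not append EOS on its own
```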
What I Need Help With:

- EOS Token Masking: Is the masking logic I implemented for ignoring tokens after the first EOS correct? Specifically, I used:

  ```python
  mask = torch.arange(sequence_lengths).unsqueeze(0).to(input_ids.device)
  mask = mask < eos_positions.unsqueeze(1)
  mask.scatter_(1, eos_positions.unsqueeze(1), 1)
  ```

  Is this the best way to ensure only tokens up to and including the first EOS are considered?
- Right-Shifted Input: I prepend the EOS token to the input like this:

  ```python
  right_shifted_input_ids = torch.cat([
      torch.full((input_ids.shape[0], 1), eos_token_id, dtype=torch.long),
      input_ids[:, :-1]
  ], dim=1)
  ```

  Is this a standard way to handle the right-shift for teacher-forced evaluation? (The shifted-label comparison I am weighing against it is sketched after this list.)
- Generalization: The code is designed to evaluate multiple models, such as google/gemma-2-2b, meta-llama/Llama-3.1-8B, and gpt2. Are there any additional considerations or best practices I should follow for TFA computation across diverse models?
- Performance Optimization: Is there a more efficient way to compute the mask and apply it to the predictions and labels? My current method seems to work but might be suboptimal for larger datasets. (A cumulative-sum alternative is sketched after this list.)
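For reference, this is the cumulative-sum mask I am weighing against the scatter_-based version above. It is only a sketch: it reuses input_ids, eos_token_id, and predicted_token_ids as defined in compute_tfa, and I have not benchmarked it.

```python
# Sketch: mark every token up to and including the first EOS, without argmax/scatter_.
# Reuses input_ids, eos_token_id, predicted_token_ids from compute_tfa; unbenchmarked.
is_eos = (input_ids == eos_token_id).int()
eos_so_far = is_eos.cumsum(dim=1)       # 0 before the first EOS, 1 at it, >1 after later EOS tokens
mask = (eos_so_far - is_eos) == 0       # True up to and including the first EOS, False afterwards
# Note: a row containing no EOS at all stays fully True here, which may not match the scatter_ version.

filtered_predictions = predicted_token_ids[mask]
filtered_labels = input_ids[mask]
```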
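And for the right-shift question, this is the shifted-label comparison I mean: feed input_ids unchanged and score the logit at position t against the token at position t + 1, which is the same shift Hugging Face applies internally when labels=input_ids. Again just a sketch reusing the variables from compute_tfa; note that the first token of each sequence is never scored under this convention.

```python
# Sketch: teacher-forced scoring without prepending EOS.
# Feed the original input_ids; the logit at position t predicts token t + 1.
with torch.no_grad():
    logits = model(input_ids=input_ids).logits           # (batch, seq_len, vocab)

shifted_predictions = logits[:, :-1, :].argmax(dim=-1)   # predictions for tokens 1..seq_len-1
shifted_labels = input_ids[:, 1:]                        # the tokens they should match
shifted_mask = mask[:, 1:]                               # EOS/padding mask, aligned with the shifted labels
accuracy = (shifted_predictions[shifted_mask] == shifted_labels[shifted_mask]).float().mean().item()
```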
Any feedback or suggestions would be greatly appreciated!