I'm evaluating LLaMA-30B on the HellaSwag validation set with the script below. The paper reports an accuracy of 82.8, but my implementation only reaches 80.4. Each ending is scored by the sum of the log-probabilities of its tokens conditioned on the context, and the highest-scoring ending is taken as the prediction. Could you provide insights or advice on why this discrepancy might occur? Thank you!

Here is my code:
import json
import re
import argparse
from tqdm import tqdm
import torch
import os
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
import time
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    destroy_model_parallel,
)
def preprocess(text):
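    # Standard HellaSwag text cleanup: strip WikiHow-style bracket artifacts and collapse double spaces.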
    text = text.strip()
    text = text.replace(" [title]", ". ")
    text = re.sub(r"\[.*?\]", "", text)
    text = text.replace("  ", " ")
    return text
def load_hellaswag_jsonl(path):
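    # One JSON example per line (the HellaSwag validation split).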
    with open(path, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]
    return data
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
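    # Each model-parallel rank loads its own consolidated.<rank>.pth shard, so the number
    # of launched processes is expected to match the number of checkpoint files.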
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("consolidated.*.pth"))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint files found in {ckpt_dir}")
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    assert rank < len(checkpoints), f"Rank {rank} exceeds available checkpoints."
    ckpt_path = checkpoints[rank]
    print(f"Rank {rank}: Loading checkpoint from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank % torch.cuda.device_count()}')
        torch.set_default_dtype(torch.float16)
    else:
        device = torch.device('cpu')
        torch.set_default_dtype(torch.float32)
    print(f"Rank {rank}: Using device {device}")
    model = Transformer(model_args).to(device)
    model.load_state_dict(checkpoint, strict=False)
    generator = LLaMA(model, tokenizer)
    print(f"Rank {rank}: Loaded in {time.time() - start_time:.2f} seconds")
    generator.model.eval()
    return generator
def map_dataset(example, tokenizer):
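    # HellaSwag query construction: activity label + ctx_a + capitalized ctx_b,
    # with each candidate ending cleaned by preprocess().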
    ctx = preprocess(example['activity_label'] + ": " + example["ctx_a"] + " " + example["ctx_b"].capitalize())
    endings = [preprocess(ending) for ending in example['endings']]
    gold = int(example['label'])
    ctx_tokens = tokenizer.encode(ctx, bos=True, eos=False)
    choices = [tokenizer.encode(ending, bos=False, eos=False) for ending in endings]
    return ctx_tokens, choices, gold
def evaluate_accuracy(llama_model, dataset, device, limit=None):
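    # Zero-shot multiple choice: each candidate ending is scored by the total
    # log-probability of its tokens given the context, and the argmax is the prediction.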
    correct = 0
    total = min(len(dataset), limit) if limit is not None else len(dataset)
    rank = int(os.environ.get("RANK", 0))
    for idx, entry in enumerate(tqdm(dataset, desc="Evaluating", disable=(rank != 0))):
        if limit is not None and idx >= limit:
            break
        ctx_tokens, choices, gold = map_dataset(entry, llama_model.tokenizer)
        log_probs = []
        ctx_length = len(ctx_tokens)
        for choice in choices:
            input_tokens = ctx_tokens + choice
            input_ids = torch.tensor([input_tokens], dtype=torch.long).to(device)
            with torch.no_grad():
                logits = llama_model.model(input_ids, 0)
            log_probs_sequence = torch.nn.functional.log_softmax(logits, dim=-1)
            log_prob = 0.0
            # log_probs_sequence[0, i - 1] is the predicted distribution over the token at position i,
            # so only the ending tokens (positions >= ctx_length) contribute to the score.
            for i in range(ctx_length, len(input_tokens)):
                token_id = input_tokens[i]
                log_prob += log_probs_sequence[0, i - 1, token_id].item()
            log_probs.append(log_prob)
        pred = log_probs.index(max(log_probs))
        if pred == gold:
            correct += 1
    correct_tensor = torch.tensor(correct).to(device)
    total_tensor = torch.tensor(total).to(device)
    # Sum per-rank counts onto rank 0 before computing accuracy.
    torch.distributed.reduce(correct_tensor, dst=0, op=torch.distributed.ReduceOp.SUM)
    torch.distributed.reduce(total_tensor, dst=0, op=torch.distributed.ReduceOp.SUM)
    if rank == 0:
        accuracy = (correct_tensor.item() / total_tensor.item()) * 100
        print(f"Accuracy on HellaSwag: {accuracy:.2f}%")
def main():
    parser = argparse.ArgumentParser(description="Evaluate LLaMA model on HellaSwag dataset")
    parser.add_argument(
        "--data-path",
        type=str,
        default="/home/n7/seongkyoon/hellaswag_val.jsonl",
        help="Path to HellaSwag JSONL file"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="/data/s1/open-source-llm/llama_v1/llama/llama_pretrained/30B",
        help="Path to the checkpoint directory (default: /data/s1/open-source-llm/llama_v1/llama/llama_pretrained/30B)"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="/data/s1/open-source-llm/llama_v1/llama/llama_pretrained/tokenizer.model",
        help="Path to the tokenizer model (default: /data/s1/open-source-llm/llama_v1/llama/llama_pretrained/tokenizer.model)"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit the number of examples to evaluate (default: None, evaluate all)"
    )
    args = parser.parse_args()
    # Distributed / fairscale model-parallel setup.
    torch.distributed.init_process_group(backend='nccl')
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(world_size)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank % torch.cuda.device_count()}')
        torch.set_default_dtype(torch.float16)
    else:
        device = torch.device('cpu')
        torch.set_default_dtype(torch.float32)
    print(f"Rank {int(os.environ.get('RANK', 0))}: Using device {device}")
    if int(os.environ.get('RANK', 0)) == 0:
        print("Loading HellaSwag data...")
    raw_data = load_hellaswag_jsonl(args.data_path)
    if int(os.environ.get('RANK', 0)) == 0:
        print(f"Number of examples to evaluate: {len(raw_data) if args.limit is None else min(len(raw_data), args.limit)}")
    if int(os.environ.get('RANK', 0)) == 0:
        print("Loading LLaMA model and tokenizer...")
    llama_model = load(
        ckpt_dir=args.ckpt_dir,
        tokenizer_path=args.tokenizer_path,
        max_seq_len=1024,
        max_batch_size=4
    )
    if int(os.environ.get('RANK', 0)) == 0:
        print("Evaluating model accuracy on HellaSwag...")
    evaluate_accuracy(llama_model, raw_data, device, limit=args.limit)
    destroy_model_parallel()
    torch.distributed.destroy_process_group()
if __name__ == "__main__":
    main()
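For context on how it runs: the script reads RANK, WORLD_SIZE, and LOCAL_RANK from the environment and sets up NCCL plus fairscale model parallelism, so it needs to be started with a launcher that sets those variables, one process per consolidated.*.pth shard (for example torchrun --nproc_per_node <num_shards> eval_script.py --ckpt_dir ... --tokenizer_path ..., where the script name is just a placeholder).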