I get very different WER using the Trainer vs. an evaluation loop I wrote myself
Using Trainer:
```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=model_filename_hugging,
    group_by_length=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    evaluation_strategy="steps",
    num_train_epochs=1,
    fp16=True,
    gradient_checkpointing=True,
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    # compute_metrics is a factory that returns the actual metric callable (see below)
    compute_metrics=compute_metrics(processor),
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor.feature_extractor,
)

result_dict = trainer.evaluate(eval_dataset=test_dataset)
print(result_dict)
```
It gives:

```
{'eval_loss': 0.4389318525791168, 'eval_wer': 0.3150713252015712, 'eval_runtime': 12.7125, 'eval_samples_per_second': 132.153, 'eval_steps_per_second': 16.519}
```

i.e. WER = 31.5% with the Trainer.
where compute_metrics is:

```python
import numpy as np
from datasets import load_metric

wer_metric = load_metric("wer")

def compute_metrics(processor):
    def __call__(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)
        # replace the -100 loss-masking value with the pad token id before decoding
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}
    return __call__
```
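
To rule out the metric code itself, the returned closure can be exercised directly on fake inputs, outside of Trainer. This is only a sanity-check sketch; the shapes and random inputs below are made up for illustration:

```python
# Standalone check of the metrics closure, independent of Trainer's prediction gathering.
# fake_logits has shape (batch, time, vocab); fake_labels has shape (batch, seq_len).
import numpy as np
from transformers import EvalPrediction

metrics_fn = compute_metrics(processor)
vocab_size = len(processor.tokenizer)
fake_logits = np.random.randn(4, 60, vocab_size).astype(np.float32)
fake_labels = np.random.randint(0, vocab_size, size=(4, 12))
print(metrics_fn(EvalPrediction(predictions=fake_logits, label_ids=fake_labels)))
```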
Using a custom loop:
```python
import numpy as np
import torch
import tqdm
from datasets import load_metric
from torch.utils.data import DataLoader

print("batch decoding")
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
eval_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=8)
metric = load_metric("wer")

model.eval()
for step, batch in tqdm.tqdm(enumerate(eval_dataloader)):
    with torch.no_grad():
        input_values = torch.tensor(np.array(batch["input_values"])).cuda()
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        pred_str = processor.batch_decode(pred_ids)
        # same -100 -> pad token replacement as in compute_metrics
        batch["labels"][batch["labels"] == -100] = processor.tokenizer.pad_token_id
        label_str = processor.batch_decode(batch["labels"], group_tokens=False)
        metric.add_batch(predictions=pred_str, references=label_str)

print("wer computation")
eval_metric = metric.compute()
print(f"wer: {eval_metric}")
```
which gives WER = 20.4%.
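
To narrow down where the 31.5% vs. 20.4% gap comes from, here is a diagnostic sketch I would try (my own idea, not part of either pipeline above): run trainer.predict and decode its gathered predictions with exactly the same code as the manual loop, so the only remaining difference is how Trainer collects and pads predictions across batches:

```python
# Diagnostic sketch: decode Trainer's gathered predictions with the manual-loop code.
# Assumes trainer, test_dataset, processor, and wer_metric are defined as above.
import numpy as np

pred_output = trainer.predict(test_dataset)  # returns predictions, label_ids, metrics
pred_ids = np.argmax(pred_output.predictions, axis=-1)
label_ids = pred_output.label_ids.copy()
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
label_str = processor.batch_decode(label_ids, group_tokens=False)
print(wer_metric.compute(predictions=pred_str, references=label_str))
```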
Can anyone help me find the bug?
(I'm using transformers 4.16.2 and datasets 1.18.3.)