When I try to run `BertForQuestionAnswering` with a `Trainer` object, it reaches the end of the eval loop before throwing `KeyError: 'eval_loss'` (full traceback below).
My implementation is deliberately vanilla and follows the "Fine-tuning with custom datasets" QA tutorial very closely. Training and evaluation both run to completion, but judging from the traceback, something goes wrong when the results are reported back to the progress display.
Am I missing something that should be there? Is this a bug? Is `Trainer` not supported here? This is `transformers` v3.4.0.
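For reference, a quick version check (this is the environment the traceback below comes from; Python 3.7 on SageMaker, per the paths in the traceback):

```python
import sys
import torch
import transformers

# Environment used for this report.
print(sys.version)
print("transformers:", transformers.__version__)  # 3.4.0
print("torch:", torch.__version__)
```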
Here is the full repro:

```python
import torch
from transformers import (
    BertForQuestionAnswering,
    BertTokenizerFast,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # self.encodings.keys() = ['input_ids', 'attention_mask', 'start_positions', 'end_positions']
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# train_encodings / val_encodings come from the tutorial's tokenization step
# (sketched below the script).
train_dataset = MyDataset(train_encodings)
val_dataset = MyDataset(val_encodings)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

training_args = TrainingArguments(
    output_dir="./tmp/qa_trainer_test",
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=3e-5,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()
```
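For completeness, `train_encodings` / `val_encodings` come straight from the tutorial's tokenization step, roughly like this (the `train_contexts` / `train_questions` / `train_answers` names and the precomputed `answer_end` field are from the tutorial, not shown in the script above):

```python
# Sketch of the tutorial's tokenization step (assumed, per the tutorial).
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)

def add_token_positions(encodings, answers):
    # Convert character-level answer spans to token positions.
    start_positions, end_positions = [], []
    for i, answer in enumerate(answers):
        start_positions.append(encodings.char_to_token(i, answer['answer_start']))
        end_positions.append(encodings.char_to_token(i, answer['answer_end'] - 1))
        # char_to_token returns None when the answer was truncated away.
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(train_encodings, train_answers)
```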
Traceback:

```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-22-7b137ef43258> in <module>
20 )
21
---> 22 trainer.train()
~/SageMaker/conda_env/my_env/lib/python3.7/site-packages/transformers/trainer.py in train(self, model_path, trial)
790
791 self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
--> 792 self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
793
794 if self.args.tpu_metrics_debug or self.args.debug:
~/SageMaker/conda_env/my_env/lib/python3.7/site-packages/transformers/trainer.py in _maybe_log_save_evalute(self, tr_loss, model, trial, epoch)
843 metrics = self.evaluate()
844 self._report_to_hp_search(trial, epoch, metrics)
--> 845 self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
846
847 if self.control.should_save:
~/SageMaker/conda_env/my_env/lib/python3.7/site-packages/transformers/trainer_callback.py in on_evaluate(self, args, state, control, metrics)
350 def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
351 control.should_evaluate = False
--> 352 return self.call_event("on_evaluate", args, state, control, metrics=metrics)
353
354 def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
~/SageMaker/conda_env/my_env/lib/python3.7/site-packages/transformers/trainer_callback.py in call_event(self, event, args, state, control, **kwargs)
374 train_dataloader=self.train_dataloader,
375 eval_dataloader=self.eval_dataloader,
--> 376 **kwargs,
377 )
378 # A Callback can skip the return of `control` if it doesn't change it.
~/SageMaker/conda_env/my_env/lib/python3.7/site-packages/transformers/utils/notebook.py in on_evaluate(self, args, state, control, metrics, **kwargs)
324 else:
325 values["Step"] = state.global_step
--> 326 values["Validation Loss"] = metrics["eval_loss"]
327 _ = metrics.pop("total_flos", None)
328 _ = metrics.pop("epoch", None)
KeyError: 'eval_loss'
```
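If it helps narrow this down, evaluation can also be triggered outside the callback machinery. A minimal sketch (assuming the `trainer` above) that shows whether `eval_loss` is in the metrics dict at all:

```python
# Sketch: call evaluate() directly and inspect the metrics dict that the
# notebook callback later indexes with metrics["eval_loss"].
metrics = trainer.evaluate()
print(sorted(metrics.keys()))
print("eval_loss" in metrics)  # expecting False, given the traceback
```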