Context:
I have fine-tuned my model on the squad_v2 dataset using the run_qa.py script. Now I am trying to use the same script to only evaluate the fine-tuned model on the dattatreya303/covid-qa-tts dataset. I created this dataset by adapting covid_qa_deepset and introducing train/test/val splits. However, using it in the command below raises a ValueError (full stack trace at the end of this post).
Questions
I do not understand the error message “Predictions and/or references don’t match the expected format.” As far as I can tell, the inputs printed in the error message match the expected format, and the data is in the correct SQuAD format. Have I missed anything in the loading script of the dattatreya303/covid-qa-tts dataset?
Or am I missing an argument in the run_qa.py command?
Any help is appreciated!
Command:
!python run_qa.py \
--model_name_or_path ./ft-roberta-squadv2/checkpoint-31500/ \
--dataset_name dattatreya303/covid-qa-tts \
--do_eval \
--per_device_eval_batch_size 12 \
--learning_rate 1e-5 \
--max_seq_length 384 \
--doc_stride 128 \
--version_2_with_negative \
--output_dir ./ft-roberta-squadv2-eval-cqa/
Stacktrace:
Traceback (most recent call last):
  File "run_qa.py", line 684, in <module>
    main()
  File "run_qa.py", line 641, in main
    metrics = trainer.evaluate()
  File "/content/transformers/examples/pytorch/question-answering/trainer_qa.py", line 58, in evaluate
    metrics = self.compute_metrics(eval_preds)
  File "run_qa.py", line 603, in compute_metrics
    return metric.compute(predictions=p.predictions, references=p.label_ids)
  File "/usr/local/lib/python3.7/dist-packages/evaluate/module.py", line 432, in compute
    self.add_batch(**inputs)
  File "/usr/local/lib/python3.7/dist-packages/evaluate/module.py", line 512, in add_batch
    raise ValueError(error_msg) from None
ValueError: Predictions and/or references don't match the expected format.
Expected format: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None), 'no_answer_probability': Value(dtype='float32', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}},
Input predictions: [{'id': 283, 'prediction_text': 'Betacoronavirus', 'no_answer_probability': 0.0}, {'id': 431, 'prediction_text': 'double-stranded', 'no_answer_probability': 0.0}, {'id': 4187, 'prediction_text': 'lapses in infection prevention and control (IPC) in healthcare settings', 'no_answer_probability': 0.0}, ..., {'id': 2771, 'prediction_text': 'the expected number of secondary infections', 'no_answer_probability': 0.0}, {'id': 3254, 'prediction_text': 'Persistent high fever, dyspnea and rapid progression to respiratory failure within 2 weeks', 'no_answer_probability': 0.0}, {'id': 3628, 'prediction_text': 'to identify published studies in accordance with the Preferred Reporting Items for Systematic Reviews and Meta-Analyses (PRISMA) guidelines', 'no_answer_probability': 0.0}],
Input references: [{'id': 283, 'answers': {'text': ['Betacoronavirus'], 'answer_start': [1723]}}, {'id': 431, 'answers': {'text': ['double-stranded ribonucleic acid'], 'answer_start': [4111]}}, {'id': 4187, 'answers': {'text': ['to lapses in infection prevention and control (IPC) in healthcare settings'], 'answer_start': [1932]}}, ..., {'id': 2771, 'answers': {'text': ['2.7-3.4 or 2-4 in Hong Kong'], 'answer_start': [15034]}}, {'id': 3254, 'answers': {'text': ['Persistent high fever, dyspnea and rapid progression to respiratory failure within 2 weeks, together with bilateral consolidations and infiltrates at the same time, are the most frequent clinical manifestations'], 'answer_start': [13068]}}, {'id': 3628, 'answers': {'text': ['to identify published studies examining the diagnosis, therapeutic drugs and vaccines for Severe Acute Respiratory Syndrome (SARS), Middle East Respiratory Syndrome (MERS) and the 2019 novel coronavirus (2019-nCoV), in accordance with the Preferred Reporting Items for Systematic Reviews and Meta-Analyses (PRISMA) guidelines.'], 'answer_start': [4552]}}]
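
One detail that stands out when comparing the expected format with the inputs above: the expected format declares 'id' as Value(dtype='string'), while the ids in both the predictions and the references are integers (283, 431, ...). The snippet below is only a minimal sketch, assuming that this type mismatch is what the metric is rejecting. The helper name ids_to_str is hypothetical and not part of run_qa.py, and the two sample records are copied from the error output.

```python
def ids_to_str(records):
    """Return copies of the records with the 'id' field cast to str.

    Hypothetical helper for illustration; not part of run_qa.py or evaluate.
    """
    return [{**rec, "id": str(rec["id"])} for rec in records]


# Sample records taken from the error message above (ids are ints here).
predictions = [
    {"id": 283, "prediction_text": "Betacoronavirus", "no_answer_probability": 0.0},
]
references = [
    {"id": 283, "answers": {"text": ["Betacoronavirus"], "answer_start": [1723]}},
]

# Cast the ids so they agree with the 'string' dtype in the expected format.
predictions = ids_to_str(predictions)
references = ids_to_str(references)

print(predictions[0]["id"], type(predictions[0]["id"]).__name__)
```

If this is indeed the cause, the same cast could presumably be applied once at the source, for example by emitting string ids in the dataset's loading script, rather than patching the lists right before metric.compute.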