@Neuroinformatica sorry for the delay. Let me see if I can help out.
Maybe one thing to try is to declare the path to your evaluation data in the Trainer
object. Something like this:
```python
trainer = Trainer(
    my_model,
    data_collator=data_collator,
    tokenizer=my_tokenizer,
    eval_dataset=eval_data_path
)
trainer.evaluate()
```
I’m not sure how much of a difference that will make, but it’s worth a shot to knock out some low-hanging fruit.
I could be wrong about this, but the error looks like it’s trying to access a dictionary key equal to 0. My intuition is telling me that somewhere in the data loading or collation there needs to be a translation between the labels and their index positions in a list. So if I had a list of labels `['red', 'orange', 'blue']`, I would need to translate them to their index values, `[0, 1, 2]`. That’s why I think the `KeyError: 0` error is occurring. That being said, I’m not too familiar with using BERT for SQuAD.
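To make that translation idea concrete, here’s a minimal sketch (the label names are just made up for illustration) of mapping string labels to integer indices the way most training pipelines expect:

```python
# Hypothetical example: translate string labels to their integer
# index positions, and back again.
labels = ['red', 'orange', 'blue']

# label -> index lookup (this is the "translation" I mean)
label2id = {label: idx for idx, label in enumerate(labels)}
# index -> label lookup, for decoding predictions
id2label = {idx: label for label, idx in label2id.items()}

print(label2id)  # {'red': 0, 'orange': 1, 'blue': 2}

# Encoding a batch of string labels as indices:
batch = ['blue', 'red', 'red']
encoded = [label2id[name] for name in batch]
print(encoded)  # [2, 0, 0]
```

If the data loading step hands the collator string labels but downstream code indexes with integers (or vice versa), you’d see exactly a `KeyError: 0`.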
Another option to consider is to pass the labels list into the `Trainer` object. To do this, you would also need to create a `TrainingArguments` object. Something like this:
```python
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments
from transformers import default_data_collator
import torch
import json

# Model from HuggingFace
model_checkpoint = 'mrm8488/bert-italian-finedtuned-squadv1-it-alfa'

# Import tokenizer
my_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Import model
my_model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

# Dataset for evaluation
eval_data_path = '/content/drive/MyDrive/BERT/SQuAD_files/result.json'
with open(eval_data_path) as json_file:
    data = json.load(json_file)

data_collator = default_data_collator

my_train_args = TrainingArguments(
    output_dir='/path/to/where/you/want/the/output',
    label_names=list_of_label_names
)

trainer = Trainer(
    my_model,
    data_collator=data_collator,
    tokenizer=my_tokenizer,
    args=my_train_args
)

trainer.evaluate(data)
```
Like I said, I’m not super familiar with using BERT for QA tasks, so maybe @sgugger has some better insight.