I’m fine-tuning the zero-shot facebook/bart-large-mnli model.
This is the metric I used (from here):
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
This is how I train it:
from transformers import BartForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=model_directory,      # output directory
    num_train_epochs=1,              # total number of training epochs - 3
    per_device_train_batch_size=1,   # batch size per device during training - 16
    per_device_eval_batch_size=2,    # batch size for evaluation - 64
    warmup_steps=50,                 # number of warmup steps for learning rate scheduler - 500
    weight_decay=0.01,               # strength of weight decay
    evaluation_strategy="epoch"      # evaluate at the end of each epoch
)

model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")  # , num_labels=len(label_to_int), ignore_mismatched_sizes=True

trainer = Trainer(
    model=model,                      # the instantiated 🤗 Transformers model to be trained
    args=training_args,               # training arguments, defined above
    compute_metrics=compute_metrics,  # a function to compute the metrics
    train_dataset=train_dataset,      # training dataset
    eval_dataset=test_dataset         # evaluation dataset
)

# Run training
trainer.train()
And I get the following error:
ValueError: could not broadcast input array from shape (132,3) into shape (132,)
The error seems to appear once training is almost finished, so I also tried running the evaluation on its own:
trainer.evaluate()
And I get the following:
The following columns in the evaluation set don't have a corresponding argument in `BartForSequenceClassification.forward` and have been ignored: input_sentence. If input_sentence are not expected by `BartForSequenceClassification.forward`, you can safely ignore this message.
***** Running Evaluation *****
Num examples = 132
Batch size = 8
warning: Databricks notebooks do not support updating results across cells.
ValueError: could not broadcast input array from shape (132,3) into shape (132,)
Why is this? How can I fix this?
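
One possible explanation, stated as an assumption rather than a confirmed diagnosis: sequence-to-sequence classifiers such as BART can return more than just the logits, in which case the first element of `eval_pred` would be a tuple of arrays with different shapes (for example one of shape (132, 3) plus a hidden-state array) instead of a single array. The sketch below shows a `compute_metrics` that unwraps such a tuple before calling `np.argmax`. The helper name `extract_logits` is hypothetical, the fake arrays only simulate the shapes from the error message, and accuracy is computed in plain NumPy as a stand-in for `evaluate.load("accuracy")`.

```python
import numpy as np

def extract_logits(predictions):
    """Return the classification logits from the Trainer's predictions.

    Assumption: if the model returned several outputs, the logits come first.
    """
    if isinstance(predictions, (tuple, list)):
        return np.asarray(predictions[0])
    return np.asarray(predictions)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    logits = extract_logits(predictions)
    predicted_classes = np.argmax(logits, axis=-1)
    # Plain-NumPy accuracy, standing in for the `evaluate` accuracy metric
    accuracy = float((predicted_classes == np.asarray(labels)).mean())
    return {"accuracy": accuracy}

# Simulated evaluation output: 132 examples, 3 classes,
# plus an extra array like the one a seq2seq model might return.
rng = np.random.default_rng(0)
fake_logits = rng.normal(size=(132, 3))
fake_hidden = rng.normal(size=(132, 4, 8))
fake_labels = np.argmax(fake_logits, axis=-1)

result = compute_metrics(((fake_logits, fake_hidden), fake_labels))
print(result)
```

If the tuple assumption holds, passing this `compute_metrics` to `Trainer` should avoid the broadcast error. `Trainer` also accepts a `preprocess_logits_for_metrics` argument, which may be a better place to drop the extra outputs since it runs before the predictions are accumulated.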