Broadcast error in multi-class classification fine-tuning

I have a dataset of text documents, each of which has one of 30 labels. I want to set up fine-tuning on this dataset for both multi-class (single-label) prediction and multi-class, multi-label prediction. This post is mainly about the first case: making single-label predictions after fine-tuning.

I have been fine-tuning from some checkpoints (e.g. microsoft/mpnet-base); that runs without error, and I'm able to evaluate predictions on a hold-out set.

However, the fine-tuned MPNet model almost always predicts the same class, which doesn't make sense because that class is not the majority class. To troubleshoot, I am trying a different checkpoint that was specifically trained on a multi-class task, and joeddav/bart-large-mnli-yahoo-answers looked like a good candidate.

I first ran into an error because the checkpoint's classification head has a different number of classes from my dataset (side note: I'm not sure why I didn't hit this error with MPNet; can anyone explain the difference?). I found the ignore_mismatched_sizes=True setting here on the forum, and it seems to stop model loading from throwing an error.

Now, however, training starts but a new error is thrown when the evaluation metrics are computed. It looks like compute_metrics is getting a 2-D array with 30 columns and can't convert it into a single most-probable class per observation. I'm not sure how to fix this; can anyone make a suggestion?
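The error message can be reproduced with NumPy alone, under one assumption: that compute_metrics receives a tuple of arrays (for example, the logits plus a 3-D hidden-state array) rather than a single logits array. np.argmax converts its input with np.asarray, and arrays of different shapes cannot be stacked into one array. The shapes below are made up and scaled down; the exact wording of the ValueError depends on the NumPy version.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: 4 examples, 30 classes, a 7-token sequence, hidden size 16.
logits = rng.random((4, 30))
hidden_states = rng.random((4, 7, 16))

# What compute_metrics would see if the model returned several outputs.
predictions = (logits, hidden_states)


def argmax_fails(preds):
    """Return True if np.argmax cannot turn `preds` into a single array."""
    try:
        np.argmax(preds, axis=-1)
    except ValueError as err:
        print("ValueError:", err)
        return True
    return False


print(argmax_fails(predictions))     # the tuple of mismatched arrays fails
print(argmax_fails(predictions[0]))  # the logits on their own work fine
```

If the tuple case fails while the logits alone succeed, the problem is the container passed to np.argmax, not the 30-column shape of the logits.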

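import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

n_labels = 30  # number of classes in the dataset

checkpoint = "joeddav/bart-large-mnli-yahoo-answers"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model_single = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=n_labels,
    problem_type="single_label_classification",
    ignore_mismatched_sizes=True  # replace the checkpoint's 3-class head with a 30-class one
)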

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at joeddav/bart-large-mnli-yahoo-answers and are newly initialized because the shapes did not match:
classification_head.out_proj.weight: found shape torch.Size([3, 1024]) in the checkpoint and torch.Size([30, 1024]) in the model instantiated
classification_head.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([30]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    f1 = f1_score(labels, predictions, average="weighted")
    f1_micro = f1_score(labels, predictions, average='micro')
    f1_macro = f1_score(labels, predictions, average='macro')
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc, "f1": f1, "f1_macro": f1_macro, "f1_micro": f1_micro}

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=6,              # total number of training epochs
    per_device_train_batch_size=2,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch",
)
trainer_singleclass = Trainer(
    model=model_single,                         
    args=training_args,                  
    train_dataset=train_dataset_singlelabel,         
    eval_dataset=val_dataset_singlelabel,            
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
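Another option, if the installed transformers version supports it, is Trainer's preprocess_logits_for_metrics argument, which is applied to each evaluation batch before predictions are accumulated. Keeping only the logits there would also avoid gathering large extra tensors for every evaluation example. This is a sketch: select_logits is a hypothetical helper name, and it assumes that when the model returns several outputs, they arrive as a tuple whose first element is the classification logits.

```python
def select_logits(logits, labels):
    """Hypothetical helper for Trainer(preprocess_logits_for_metrics=...).

    Assumes multi-output models pass a tuple whose first element holds the
    classification logits. `labels` is unused but is part of the signature
    Trainer calls this function with.
    """
    if isinstance(logits, tuple):
        return logits[0]  # drop extra outputs such as hidden states
    return logits
```

It would be passed as Trainer(..., compute_metrics=compute_metrics, preprocess_logits_for_metrics=select_logits). Because it returns the raw logits rather than class ids, the existing compute_metrics, including its np.argmax call, can stay unchanged.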
trainer_singleclass.train()
***** Running training *****
  Num examples = 864
  Num Epochs = 6
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 324
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '

***** Running Evaluation *****
  Num examples = 216
  Batch size = 512
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-25-582a21eea64d> in <module>
      1 # train single class fine-tune on MPNet (or whatever :checkpoint: is)
----> 2 trainer_singleclass.train()

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1488 
   1489             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 1490             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1491 
   1492             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   1600         metrics = None
   1601         if self.control.should_evaluate:
-> 1602             metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   1603             self._report_to_hp_search(trial, epoch, metrics)
   1604 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   2262             prediction_loss_only=True if self.compute_metrics is None else None,
   2263             ignore_keys=ignore_keys,
-> 2264             metric_key_prefix=metric_key_prefix,
   2265         )
   2266 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   2503         # Metrics!
   2504         if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
-> 2505             metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
   2506         else:
   2507             metrics = {}

<ipython-input-16-8829ed159344> in compute_metrics(eval_preds)
      1 def compute_metrics(eval_preds):
      2     logits, labels = eval_preds
----> 3     predictions = np.argmax(logits, axis=-1)
      4     f1 = f1_score(labels, predictions, average="weighted")
      5     f1_micro = f1_score(labels, predictions, average='micro')

<__array_function__ internals> in argmax(*args, **kwargs)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpy/core/fromnumeric.py in argmax(a, axis, out)
   1151 
   1152     """
-> 1153     return _wrapfunc(a, 'argmax', axis=axis, out=out)
   1154 
   1155 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     bound = getattr(obj, method, None)
     57     if bound is None:
---> 58         return _wrapit(obj, method, *args, **kwds)
     59 
     60     try:

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpy/core/fromnumeric.py in _wrapit(obj, method, *args, **kwds)
     45     except AttributeError:
     46         wrap = None
---> 47     result = getattr(asarray(obj), method)(*args, **kwds)
     48     if wrap:
     49         if not isinstance(result, mu.ndarray):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

ValueError: could not broadcast input array from shape (216,30) into shape (216)
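A likely explanation: BartForSequenceClassification is an encoder-decoder model, and its output includes extra tensors (such as encoder_last_hidden_state) alongside the logits, so Trainer may pass compute_metrics a tuple instead of a single (216, 30) array. np.argmax then tries to stack the (216, 30) logits with a 3-D hidden-state array and fails. The sketch below keeps the original metrics but unpacks the tuple first, assuming the logits are its first element.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score


def compute_metrics(eval_preds):
    logits, labels = eval_preds
    # Encoder-decoder models such as BART can return several outputs, in which
    # case `logits` is a tuple; assume the classification logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)  # one class id per example
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average="weighted"),
        "f1_macro": f1_score(labels, predictions, average="macro"),
        "f1_micro": f1_score(labels, predictions, average="micro"),
    }
```

For models that return a single logits array, such as the MPNet setup that already worked, the isinstance check is skipped and the function behaves like the original.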