High inconsistencies while training

Hi,

I am training models on K-fold (class-stratified) samples, and I am getting very different performance from one fold to the next. I was wondering whether this is due to bad luck in the sampling or to my training settings.

{

  "model_name_or_path": "bert-base-multilingual-cased",
  "seed": 42,
  "load_best_model_at_end": true,
  "metric_for_best_model": "f1_avg",
  "log_level": "error",
  "num_train_epochs":3.0,
  "evaluation_strategy": "epoch",
  "max_steps": -1,
  "save_strategy": "epoch",
  "per_device_eval_batch_size":32,
  "per_device_train_batch_size":32
}

  # Initialize our Trainer
  # NOTE(review): Trainer expects tokenized Dataset objects for
  # train_dataset/eval_dataset. Passing data_args.train_file works only if
  # these attributes already hold preprocessed datasets (not file paths) —
  # confirm against the upstream loading code.
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=data_args.train_file if training_args.do_train else None,
      eval_dataset=data_args.validation_file if training_args.do_eval else None,
      compute_metrics=compute_metrics
  )

  # Fine-tune, then persist the final model. With load_best_model_at_end set
  # in the config, the weights saved here are the best-by-f1_avg checkpoint,
  # not necessarily the last epoch's.
  trainer.train()
  trainer.save_model(output_dir=training_args.output_dir)

  if training_args.do_predict:
    print("*** Predict ***")
    # Drop gold labels and raw text so predict() only sees model inputs.
    predict_dataset = data_args.test_file.remove_columns(["label","text"])
    predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
    # Logits -> predicted class ids.
    predictions = np.argmax(predictions, axis=-1)
    # Map string labels to ids so they are comparable with the argmax output.
    gold = list((label2id[i] for i in data_args.test_file['label']))

    return compute_metrics_final(predictions, gold)
  else: return

# Metrics are loaded once at module level and reused by compute_metrics.
# NOTE: load_metric() takes no `average` argument — averaging is chosen per
# call via .compute(..., average=...). The previous `average=None` kwargs
# were silently absorbed as metric-init kwargs and did nothing, so they are
# removed here to avoid suggesting they had an effect.
accuracy_metric = load_metric("accuracy")
precision_metric = load_metric("precision")
recall_metric = load_metric("recall")
f1_metric = load_metric("f1")

def compute_metrics(eval_pred):
    """Compute eval metrics for transformers.Trainer.

    Parameters
    ----------
    eval_pred : tuple
        ``(logits, labels)`` as supplied by ``Trainer`` — logits of shape
        (batch, num_classes) and integer reference labels.

    Returns
    -------
    dict
        ``accuracy`` plus weighted-average ``precision_avg``, ``recall_avg``
        and ``f1_avg`` (the key ``metric_for_best_model`` points at).
    """
    logits, labels = eval_pred
    # Highest-scoring class per example.
    predictions = np.argmax(logits, axis=-1)
    # Leftover debug print of predictions.shape removed.
    return {
        # accuracy_metric was loaded at module level but never used before;
        # include it so eval logs report accuracy alongside the averages.
        "accuracy": accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"],
        "precision_avg": precision_metric.compute(predictions=predictions, references=labels, average="weighted")["precision"],
        "recall_avg": recall_metric.compute(predictions=predictions, references=labels, average="weighted")["recall"],
        "f1_avg": f1_metric.compute(predictions=predictions, references=labels, average="weighted")["f1"],
    }

Result samples

src/datasets/fr/classic/training_0/5/
{'eval_loss': 1.514719009399414, 'eval_precision_avg': 0.2334592880637693, 'eval_recall_avg': 0.28966914247130315, 'eval_f1_avg': 0.1830019082583156, 'eval_runtime': 3.3024, 'eval_samples_per_second': 448.467, 'eval_steps_per_second': 3.634, 'epoch': 1.0}
{'eval_loss': 1.4260002374649048, 'eval_precision_avg': 0.26976654143954865, 'eval_recall_avg': 0.36664415935178934, 'eval_f1_avg': 0.2938495440080509, 'eval_runtime': 3.2914, 'eval_samples_per_second': 449.959, 'eval_steps_per_second': 3.646, 'epoch': 2.0}
{'eval_loss': 1.3929047584533691, 'eval_precision_avg': 0.36920133539327155, 'eval_recall_avg': 0.39702903443619175, 'eval_f1_avg': 0.3428423156696124, 'eval_runtime': 3.2958, 'eval_samples_per_second': 449.359, 'eval_steps_per_second': 3.641, 'epoch': 3.0}
{'train_runtime': 95.8073, 'train_samples_per_second': 139.029, 'train_steps_per_second': 1.096, 'train_loss': 1.4999622163318453, 'epoch': 3.0}
*** Predict ***
{'accuracy': 0.37635135135135134, 'precision': [0.3157894736842105, 0.32455089820359284, 0.0, 0.4900398406374502, 0.43026706231454004], 'recall': [0.06143344709897611, 0.7324324324324324, 0.0, 0.634020618556701, 0.35365853658536583], 'f1': [0.10285714285714284, 0.44979253112033196, 0.0, 0.5528089887640449, 0.38821954484605087], 'precision_avg': 0.327086062673939, 'recall_avg': 0.37635135135135134, 'f1_avg': 0.31282117336403464}

src/datasets/fr/classic/training_0/7/
{'eval_loss': 1.5796321630477905, 'eval_precision_avg': 0.1315343569350999, 'eval_recall_avg': 0.24848075624577987, 'eval_f1_avg': 0.11325190173413503, 'eval_runtime': 3.2921, 'eval_samples_per_second': 449.868, 'eval_steps_per_second': 3.645, 'epoch': 1.0}
{'eval_loss': 1.574394702911377, 'eval_precision_avg': 0.06074011528426009, 'eval_recall_avg': 0.24645509790681971, 'eval_f1_avg': 0.09746057501190596, 'eval_runtime': 3.2938, 'eval_samples_per_second': 449.635, 'eval_steps_per_second': 3.643, 'epoch': 2.0}
{'eval_loss': 1.569918155670166, 'eval_precision_avg': 0.06074011528426009, 'eval_recall_avg': 0.24645509790681971, 'eval_f1_avg': 0.09746057501190596, 'eval_runtime': 3.2905, 'eval_samples_per_second': 450.077, 'eval_steps_per_second': 3.647, 'epoch': 3.0}
{'train_runtime': 95.9071, 'train_samples_per_second': 138.884, 'train_steps_per_second': 1.095, 'train_loss': 1.5773372105189731, 'epoch': 3.0}
*** Predict ***
{'accuracy': 0.2472972972972973, 'precision': [0.0, 0.2510431154381085, 0.0, 0.0, 0.11904761904761904], 'recall': [0.0, 0.9652406417112299, 0.0, 0.0, 0.012626262626262626], 'f1': [0.0, 0.3984547461368654, 0.0, 0.0, 0.0228310502283105], 'precision_avg': 0.09529255561939845, 'recall_avg': 0.2472972972972973, 'f1_avg': 0.10679943982810718}

src/datasets/fr/classic/training_0/6/
{'eval_loss': 1.4923285245895386, 'eval_precision_avg': 0.2233032981309156, 'eval_recall_avg': 0.32478055367994596, 'eval_f1_avg': 0.23654619315539652, 'eval_runtime': 3.2891, 'eval_samples_per_second': 450.281, 'eval_steps_per_second': 3.648, 'epoch': 1.0}
{'eval_loss': 1.438174843788147, 'eval_precision_avg': 0.34855820876781185, 'eval_recall_avg': 0.3639432815665091, 'eval_f1_avg': 0.29301501204942587, 'eval_runtime': 3.6235, 'eval_samples_per_second': 408.724, 'eval_steps_per_second': 3.312, 'epoch': 2.0}
{'eval_loss': 1.4233169555664062, 'eval_precision_avg': 0.332615430700565, 'eval_recall_avg': 0.37339635381498987, 'eval_f1_avg': 0.3141901343545122, 'eval_runtime': 3.2869, 'eval_samples_per_second': 450.571, 'eval_steps_per_second': 3.651, 'epoch': 3.0}
{'train_runtime': 95.8172, 'train_samples_per_second': 139.015, 'train_steps_per_second': 1.096, 'train_loss': 1.471975853329613, 'epoch': 3.0}
*** Predict ***
{'accuracy': 0.377027027027027, 'precision': [0.42857142857142855, 0.32778489116517284, 0.0, 0.4107744107744108, 0.4523076923076923], 'recall': [0.11498257839721254, 0.6701570680628273, 0.0, 0.7176470588235294, 0.3348519362186788], 'f1': [0.1813186813186813, 0.4402407566638005, 0.0, 0.5224839400428265, 0.38481675392670156], 'precision_avg': 0.34905983458096135, 'recall_avg': 0.377027027027027, 'f1_avg': 0.3229508482196864}

Thanks for your help!