XLM-R classifier predicts only a single class

Hi,

I am using the tf-xlm-r-base model for a multi-class sentiment classification task with 4 classes. I tried both the TFTrainer API and native Keras training. Initially I got acceptable results, but later the model predicts only one class for the same dataset. I am following this guide. Below are my outputs and code. My inputs and labels are similar to the example in the guide mentioned above, with class labels 0, 1, 2, 3. I am trying this on a dataset of 1,000 training data points. I also tried 7K and 14K training sets, which did not fix the problem; only the single predicted class changed.
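To rule out a label problem first, here is a quick sanity check of the class balance (a minimal sketch; comment_labels is the raw label list used below):

from collections import Counter

# counts per class; all four labels 0-3 should appear with reasonable frequency
print(Counter(comment_labels))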

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(comment_texts, comment_labels, test_size=0.1, random_state=0)

model_checkpoint = "jplu/tf-xlm-roberta-base"

from transformers import AutoTokenizer     
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# val_texts / val_labels: a held-out validation split, created the same way as the test split above
X_train, val_texts, y_train, val_labels = train_test_split(X_train, y_train, test_size=0.1, random_state=0)

train_encodings = tokenizer(X_train, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(X_test, truncation=True, padding=True)

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices((  # convert the encodings to tf.data.Dataset objects
    dict(train_encodings),
    y_train
))
val_dataset = tf.data.Dataset.from_tensor_slices((
    dict(val_encodings),
    val_labels
))
test_dataset = tf.data.Dataset.from_tensor_slices((
    dict(test_encodings),
    y_test
))
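As a quick check that the dataset objects are built correctly (optional, but it helped me while debugging):

# expect a dict with input_ids / attention_mask tensors plus a scalar label spec
print(train_dataset.element_spec)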

from transformers import TFAutoModelForSequenceClassification, TFTrainingArguments, TFTrainer

num_labels = 4
model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
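The summary output below comes from Keras:

model.summary()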

Model: "tfxlm_roberta_for_sequence_classification_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
roberta (TFRobertaMainLayer) multiple                  277453056 
_________________________________________________________________
classifier (TFRobertaClassif multiple                  593668    
=================================================================
Total params: 278,046,724
Trainable params: 278,046,724
Non-trainable params: 0

import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='weighted')
    precision = precision_score(y_true=labels, y_pred=pred, average='weighted')
    # f1 = f1_score(y_true=labels, y_pred=pred)
    return {"accuracy": accuracy, "precision": precision, "recall": recall}

training_args = TFTrainingArguments(
    output_dir='/content/drive/MyDrive/test_transformer/results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    evaluation_strategy = "epoch",
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,   # batch size for evaluation
    warmup_steps=100,                # number of warmup steps for learning rate scheduler
    weight_decay=0.001,               # strength of weight decay
    logging_dir='/content/drive/MyDrive/test_transformer/logs',            # directory for storing logs
    logging_steps=10,
)

with training_args.strategy.scope():
    # re-instantiate the model inside the distribution strategy scope, as TFTrainer expects
    model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels, output_attentions=True)

trainer = TFTrainer(
    model=model,                         # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset (this should not be the test dataset, although I am using it here)
    compute_metrics=compute_metrics
)

trainer.train()
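The PredictionOutput shown near the end of this post comes from the trainer's predict method, presumably along these lines:

# returns PredictionOutput(predictions=..., label_ids=..., metrics=...)
preds = trainer.predict(test_dataset)
print(preds.metrics)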

For native Keras training:

optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])  # any Keras loss fn also works, e.g. SparseCategoricalCrossentropy(from_logits=True)
history = model.fit(train_dataset.shuffle(1000).batch(16), epochs=10)  # batch_size must not be passed when the dataset is already batched

out = model.predict(test_dataset.batch(16))
y_pred = np.argmax(tf.nn.softmax(out.logits, axis=-1), axis=-1)  # softmax is monotonic, so argmax over the raw logits gives the same result
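The report below is produced with scikit-learn's classification_report (digits=4 gives the four decimal places shown):

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, digits=4))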

For a test dataset of 100 examples, this gives:

              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000        31
           1     0.4200    1.0000    0.5915        42
           2     0.0000    0.0000    0.0000        26
           3     0.0000    0.0000    0.0000         1

    accuracy                         0.4200       100
   macro avg     0.1050    0.2500    0.1479       100
weighted avg     0.1764    0.4200    0.2485       100


The logit outputs tally with these scores, giving the highest value to one particular class for every example (all 100 rows are nearly identical, so the array is truncated here):

PredictionOutput(predictions=array([[ 1.0289115 ,  1.1212691 ,  0.35617405, -2.9331322 ],
       [ 1.0295717 ,  1.1219745 ,  0.35612756, -2.934678  ],
       [ 1.0289289 ,  1.1211748 ,  0.3562293 , -2.9330313 ],
       [ 1.0290331 ,  1.121373  ,  0.35619155, -2.9334006 ],
       [ 1.0291184 ,  1.121406  ,  0.3562282 , -2.933531  ],
       ...
       [ 1.0290968 ,  1.1214894 ,  0.35614622, -2.9336073 ]],
      dtype=float32), label_ids=array([0, 1, 2, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 2, 2, 0, 0, 2,
       1, 1, 1, 3, 0, 2, 2, 2, 0, 1, 2, 2, 2, 0, 1, 0, 0, 1, 0, 1, 1, 2,
       2, 1, 1, 1, 0, 1, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 1, 2, 2, 1, 0,
       2, 1, 2, 0, 1, 1, 0, 2, 1, 2, 0, 1, 1, 1, 2, 0, 1, 0, 1, 2, 2, 1,
       1, 0, 0, 1, 1, 1, 1, 2, 0, 2, 1, 1]), metrics={'eval_loss': 1.13915282029372, 'eval_accuracy': 0.42, 'eval_precision': 0.1764, 'eval_recall': 0.42})

Can someone please suggest a tip or tell me what could have gone wrong? (I also tried batch sizes of 8 and 16, running on both CPU and GPU with the same datasets and parameters, changing learning rates and epochs, and downsampling the training dataset to balance the classes.)

  • I also found that binary classification (using two of the above four classes) behaves the same way. I also tried changing the labels to float type and using a different loss function.

Found the reason. In the literature, the learning rates mentioned for fine-tuning tasks are mostly in the 1e-4 to 1e-5 range, so I had been using values like 5e-5 and assumed the error was in my code setup. It turned out that an even lower learning rate (5e-6) was needed here; at 5e-5, fine-tuning collapsed to predicting a single class.
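For anyone hitting the same issue, the only change needed was the learning rate; a minimal sketch of both setups with the lower value:

# native Keras: lower LR prevents the collapse to one class
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-6)

# or with TFTrainer (TFTrainingArguments defaults to 5e-5, which was too high here)
training_args = TFTrainingArguments(
    output_dir='/content/drive/MyDrive/test_transformer/results',
    learning_rate=5e-6,
    num_train_epochs=3,
)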