I am using the tf-xlm-r-base model for a multi-class sentiment classification task with 4 classes. I tried both the Trainer API (TFTrainer) and native Keras training. At first I got acceptable results, but now the model predicts only one class on the same dataset. I am following this guide. My code and outputs are below. My inputs and labels are structured like the example in the guide, with class labels 0, 1, 2, 3. I am training on 1,000 data points. I also tried training sets of 7K and 14K, which did not solve the problem: only the single predicted class changed.
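The labels are meant to be the integers 0 to 3, so a quick check before tokenizing is to confirm their type and range and look at the class distribution. This is a minimal sketch. `check_labels` is a hypothetical helper, and the toy label list is made up, not the actual data.

```python
import numpy as np

def check_labels(labels, num_labels=4):
    """Hypothetical helper: validate integer class ids and return per-class counts."""
    arr = np.asarray(labels)
    # Labels must be integers (not strings or floats) for a sparse classification loss
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"labels have dtype {arr.dtype}, expected integers")
    # Every label must fall in [0, num_labels)
    if arr.min() < 0 or arr.max() >= num_labels:
        raise ValueError(f"labels span [{arr.min()}, {arr.max()}], expected [0, {num_labels - 1}]")
    # Per-class counts, including classes that never occur
    return np.bincount(arr, minlength=num_labels)

counts = check_labels([0, 1, 1, 2, 1, 3, 0, 2])
print(counts)                 # [2 3 2 1]
print(counts / counts.sum())  # class frequencies
```

Running this on `comment_labels` (and on each split) shows whether any class is nearly absent before training starts.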
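```python
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, TFTrainingArguments, TFTrainer

num_labels = 4  # class labels 0, 1, 2, 3

# comment_texts (list of str) and comment_labels (ints 0-3) are loaded elsewhere
X_train, X_test, y_train, y_test = train_test_split(comment_texts, comment_labels, test_size=0.1, random_state=0)
# val_texts / val_labels: validation split, created elsewhere (not shown)

model_checkpoint = "jplu/tf-xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
train_encodings = tokenizer(X_train, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(X_test, truncation=True, padding=True)

# convert to dataset objects: (features dict, label) pairs, as in the guide
train_dataset = tf.data.Dataset.from_tensor_slices((dict(train_encodings), y_train))
val_dataset = tf.data.Dataset.from_tensor_slices((dict(val_encodings), val_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((dict(test_encodings), y_test))

model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
model.summary()
```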
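```
Model: "tfxlm_roberta_for_sequence_classification_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
roberta (TFRobertaMainLayer) multiple                  277453056
classifier (TFRobertaClassif multiple                  593668
=================================================================
Total params: 278,046,724
Trainable params: 278,046,724
Non-trainable params: 0
```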
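As a sanity check on the summary above, the classifier's 593,668 parameters are consistent with a 4-class head. This assumes a hidden size of 768 for XLM-R base and a head made of a dense hidden-to-hidden layer followed by a hidden-to-num_labels projection, which is my reading of the RoBERTa classification head. `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out

hidden_size = 768  # assumed hidden size of XLM-R base
num_labels = 4

# classification head: dense (hidden -> hidden), then out_proj (hidden -> num_labels)
head = dense_params(hidden_size, hidden_size) + dense_params(hidden_size, num_labels)
print(head)                # 593668
print(277_453_056 + head)  # 278046724, the "Total params" line
```

So the head does have 4 outputs, which rules out a `num_labels` mismatch as the cause.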
```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

def compute_metrics(p):
    # p unpacks into (logits, label_ids)
    pred, labels = p
    pred = np.argmax(pred, axis=1)  # logits -> predicted class ids
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='weighted')
    precision = precision_score(y_true=labels, y_pred=pred, average='weighted')
    # f1 = f1_score(y_true=labels, y_pred=pred)
    return {"accuracy": accuracy, "precision": precision, "recall": recall}
```
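`average='weighted'` averages the per-class scores, weighting each class by its support (its number of true examples). Below is a numpy-only sketch of the same computation, useful for checking numbers by hand. `weighted_precision_recall` is a hypothetical name. As with sklearn's default, a class that is never predicted gets precision 0 (sklearn also emits a warning, which this sketch omits).

```python
import numpy as np

def weighted_precision_recall(logits, labels, num_labels):
    """Hypothetical numpy-only equivalent of precision/recall with average='weighted'."""
    pred = np.argmax(logits, axis=1)  # logits -> predicted class ids
    labels = np.asarray(labels)
    support = np.bincount(labels, minlength=num_labels)  # true examples per class
    precision = np.zeros(num_labels)
    recall = np.zeros(num_labels)
    for c in range(num_labels):
        tp = np.sum((pred == c) & (labels == c))
        n_pred = np.sum(pred == c)
        precision[c] = tp / n_pred if n_pred else 0.0  # never predicted -> 0
        recall[c] = tp / support[c] if support[c] else 0.0
    weights = support / support.sum()
    return float(precision @ weights), float(recall @ weights)

logits = np.array([[2.0, 0.1, 0.0, 0.0],
                   [0.1, 2.0, 0.0, 0.0],
                   [0.1, 2.0, 0.0, 0.0],
                   [0.0, 0.1, 2.0, 0.0]])
print(weighted_precision_recall(logits, [0, 1, 2, 2], num_labels=4))  # (0.875, 0.75)
```

Support-weighted recall is algebraically equal to accuracy, since each class contributes (support/N) * (tp/support) = tp/N. That is why `eval_recall` and `eval_accuracy` are identical in the evaluation output.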
```python
training_args = TFTrainingArguments(
    output_dir='/content/drive/MyDrive/test_transformer/results',  # output directory
    num_train_epochs=3,             # total number of training epochs
    evaluation_strategy="epoch",    # evaluate at the end of each epoch
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,   # batch size for evaluation
    warmup_steps=100,               # number of warmup steps for learning rate scheduler
    weight_decay=0.001,             # strength of weight decay
    logging_dir='/content/drive/MyDrive/test_transformer/logs',  # directory for storing logs
)
```
```python
with training_args.strategy.scope():
    model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels, output_attentions=True)

trainer = TFTrainer(
    model=model,                      # the instantiated Transformers model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=train_dataset,      # training dataset
    eval_dataset=val_dataset,         # evaluation dataset (should not be the test set, although I am using it)
    compute_metrics=compute_metrics,  # accuracy, weighted precision, weighted recall
)
trainer.train()
predictions = trainer.predict(test_dataset)  # returns a PredictionOutput
```
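To see how much of the learning-rate schedule the 100 warmup steps cover, the total step count can be worked out. This sketch assumes a single device (so the global batch size equals `per_device_train_batch_size`) and 1,000 training examples. `training_steps` is a hypothetical helper, and TFTrainer's handling of the last partial batch may differ from the `drop_last` switch used here.

```python
import math

def training_steps(num_examples, batch_size, epochs, drop_last=False):
    """Hypothetical helper: optimizer steps for a run (last partial batch kept unless drop_last)."""
    per_epoch = num_examples // batch_size if drop_last else math.ceil(num_examples / batch_size)
    return per_epoch * epochs

total = training_steps(num_examples=1000, batch_size=8, epochs=3)
print(total)        # 375
print(100 / total)  # fraction of training spent warming up, about 0.27
```

With the 7K and 14K sets, the same 100 warmup steps make up a much smaller fraction of the run.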
For native Keras training:
```python
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])  # can also use any Keras loss fn
# the dataset is already batched by .batch(16), so batch_size is not passed to fit()
history = model.fit(train_dataset.shuffle(1000).batch(16), epochs=10)
```
This gives results like the following on the test dataset:
```
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000        31
           1     0.4200    1.0000    0.5915        42
           2     0.0000    0.0000    0.0000        26
           3     0.0000    0.0000    0.0000         1

    accuracy                         0.4200       100
   macro avg     0.1050    0.2500    0.1479       100
weighted avg     0.1764    0.4200    0.2485       100
```
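The report above is exactly what a model that always predicts class 1 would score on this test set. The sketch below uses only the supports from the report (31, 42, 26, 1) and reproduces the class-1 row and both averages.

```python
import numpy as np

support = np.array([31, 42, 26, 1])  # test-set supports from the report
n = support.sum()                    # 100 test examples
constant_class = 1                   # every example predicted as class 1

precision = np.zeros(4)
recall = np.zeros(4)
precision[constant_class] = support[constant_class] / n  # 42 correct out of 100 predictions
recall[constant_class] = 1.0                             # all 42 class-1 examples are "found"
# F1 per class, defined as 0 where precision + recall == 0
f1 = np.where(precision + recall > 0,
              2 * precision * recall / np.maximum(precision + recall, 1e-12), 0.0)

weights = support / n
print(round(f1[constant_class], 4))  # about 0.5915, as in the report
print(round(precision.mean(), 4), round(recall.mean(), 4), round(f1.mean(), 4))  # macro avg
print(round(precision @ weights, 4), round(recall @ weights, 4), round(f1 @ weights, 4))  # weighted avg
```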
The logits are consistent with this. They are almost identical for every test example, and the highest value is always for class 1:
```
PredictionOutput(predictions=array([[ 1.0289115 , 1.1212691 , 0.35617405, -2.9331322 ],
[ 1.0295717 , 1.1219745 , 0.35612756, -2.934678 ],
[ 1.0289289 , 1.1211748 , 0.3562293 , -2.9330313 ],
[ 1.0290331 , 1.121373 , 0.35619155, -2.9334006 ],
[ 1.0291184 , 1.121406 , 0.3562282 , -2.933531 ],
[ 1.0291708 , 1.121502 , 0.35619184, -2.933691 ],
[ 1.0289183 , 1.121173 , 0.35626405, -2.933055 ],
[ 1.0282533 , 1.1204101 , 0.35632664, -2.9314802 ],
[ 1.0289795 , 1.1212872 , 0.35624185, -2.933268 ],
[ 1.0289418 , 1.1212839 , 0.35617486, -2.933162 ],
[ 1.0290436 , 1.1213527 , 0.35621163, -2.933384 ],
[ 1.0290844 , 1.1214646 , 0.35614225, -2.9335442 ],
[ 1.0289063 , 1.1212507 , 0.35618058, -2.9330966 ],
[ 1.0291401 , 1.1214608 , 0.3562286 , -2.9336445 ],
[ 1.0290921 , 1.1214745 , 0.3561555 , -2.9335837 ],
[ 1.0290128 , 1.1212608 , 0.35625243, -2.9332502 ],
[ 1.0291014 , 1.1213605 , 0.3562572 , -2.9334743 ],
[ 1.0291253 , 1.1214253 , 0.35621867, -2.9335544 ],
[ 1.0290672 , 1.1213386 , 0.35625044, -2.933417 ],
[ 1.0289694 , 1.1212484 , 0.35623503, -2.9331877 ],
[ 1.0291487 , 1.1214706 , 0.35619804, -2.93363 ],
[ 1.0289078 , 1.1212181 , 0.3562105 , -2.933069 ],
[ 1.0291108 , 1.1214286 , 0.35619968, -2.9335427 ],
[ 1.0291494 , 1.121471 , 0.35619843, -2.9336333 ],
[ 1.0291737 , 1.1215005 , 0.35622108, -2.9337227 ],
[ 1.0291069 , 1.1215074 , 0.3561285 , -2.9336267 ],
[ 1.0289413 , 1.1212887 , 0.3561762 , -2.9331977 ],
[ 1.0293871 , 1.1218258 , 0.35609013, -2.9342983 ],
[ 1.0290092 , 1.1213552 , 0.3561896 , -2.9333515 ],
[ 1.0283587 , 1.1205777 , 0.3562444 , -2.93175 ],
[ 1.0289493 , 1.1212447 , 0.3562152 , -2.9331517 ],
[ 1.0290315 , 1.1213847 , 0.3561749 , -2.93341 ],
[ 1.0290761 , 1.1214103 , 0.35618582, -2.9334972 ],
[ 1.0289084 , 1.1212412 , 0.35619718, -2.9331186 ],
[ 1.0290228 , 1.1213146 , 0.35622573, -2.933303 ],
[ 1.028934 , 1.1211748 , 0.3562406 , -2.933056 ],
[ 1.0292109 , 1.1215464 , 0.356189 , -2.9337885 ],
[ 1.0292367 , 1.1215956 , 0.3561677 , -2.9338856 ],
[ 1.0290639 , 1.1214087 , 0.35618123, -2.9334571 ],
[ 1.0290803 , 1.1214298 , 0.35617745, -2.933514 ],
[ 1.0291355 , 1.1214267 , 0.35623217, -2.9335864 ],
[ 1.0289494 , 1.1212884 , 0.35619822, -2.9332087 ],
[ 1.0293379 , 1.1216819 , 0.35617125, -2.9340885 ],
[ 1.0293596 , 1.1217119 , 0.35619378, -2.934169 ],
[ 1.0291876 , 1.1215029 , 0.35619885, -2.9337032 ],
[ 1.0288249 , 1.1211317 , 0.35620478, -2.9328792 ],
[ 1.0291348 , 1.1215305 , 0.35612702, -2.93368 ],
[ 1.0292716 , 1.1215887 , 0.3561728 , -2.9338777 ],
[ 1.029365 , 1.1217095 , 0.35616824, -2.9341416 ],
[ 1.0290942 , 1.1214304 , 0.35618833, -2.9335175 ],
[ 1.0289066 , 1.1211537 , 0.3562669 , -2.933033 ],
[ 1.0291075 , 1.121446 , 0.35618696, -2.9335594 ],
[ 1.0294148 , 1.121831 , 0.3561264 , -2.9343624 ],
[ 1.0292593 , 1.1216727 , 0.35611537, -2.9339767 ],
[ 1.0290027 , 1.1213193 , 0.35620502, -2.9333057 ],
[ 1.0292455 , 1.1215812 , 0.3561885 , -2.933863 ],
[ 1.0291928 , 1.1215296 , 0.35619065, -2.9337544 ],
[ 1.0289562 , 1.12131 , 0.3561774 , -2.9332416 ],
[ 1.0292183 , 1.1215689 , 0.35616964, -2.9338312 ],
[ 1.0292469 , 1.121585 , 0.35619286, -2.9338806 ],
[ 1.0292145 , 1.1215852 , 0.35615656, -2.9338334 ],
[ 1.0290105 , 1.1213479 , 0.3561993 , -2.933348 ],
[ 1.0290651 , 1.1213914 , 0.35617766, -2.9334235 ],
[ 1.0290747 , 1.1214159 , 0.3561734 , -2.9334788 ],
[ 1.0290564 , 1.1214483 , 0.35613787, -2.933495 ],
[ 1.0290995 , 1.1214293 , 0.3561933 , -2.9335384 ],
[ 1.0292302 , 1.1215992 , 0.356152 , -2.933872 ],
[ 1.0294645 , 1.1218221 , 0.35615602, -2.9344025 ],
[ 1.0291889 , 1.1215235 , 0.35620224, -2.9337523 ],
[ 1.0292081 , 1.1215614 , 0.35616958, -2.9337964 ],
[ 1.0291911 , 1.1215253 , 0.35619062, -2.9337654 ],
[ 1.028842 , 1.1211593 , 0.35621268, -2.9329333 ],
[ 1.0290655 , 1.1213945 , 0.3562099 , -2.933466 ],
[ 1.027889 , 1.1199136 , 0.35638046, -2.9304833 ],
[ 1.0290161 , 1.1213044 , 0.35623237, -2.9333255 ],
[ 1.0292108 , 1.1215161 , 0.3562079 , -2.9337723 ],
[ 1.0290184 , 1.1213537 , 0.35618034, -2.9333456 ],
[ 1.0289462 , 1.1212633 , 0.35620436, -2.9331503 ],
[ 1.0292357 , 1.1216115 , 0.3561552 , -2.933901 ],
[ 1.0290319 , 1.1213702 , 0.35618973, -2.933375 ],
[ 1.0292478 , 1.1216022 , 0.35619214, -2.93391 ],
[ 1.0288923 , 1.1211102 , 0.35628372, -2.932956 ],
[ 1.0272568 , 1.1193739 , 0.3562229 , -2.929129 ],
[ 1.0290308 , 1.1214021 , 0.35614964, -2.9334133 ],
[ 1.0291088 , 1.1214335 , 0.356207 , -2.933562 ],
[ 1.028747 , 1.1209638 , 0.3562615 , -2.932617 ],
[ 1.0292501 , 1.1215758 , 0.35620457, -2.933898 ],
[ 1.0288965 , 1.1211853 , 0.3562436 , -2.933039 ],
[ 1.0290757 , 1.121396 , 0.3562069 , -2.9334598 ],
[ 1.0290804 , 1.1213692 , 0.3562515 , -2.9334688 ],
[ 1.0292016 , 1.1215339 , 0.35618612, -2.9337628 ],
[ 1.0291742 , 1.1215011 , 0.35619456, -2.9337041 ],
[ 1.0290596 , 1.1214275 , 0.35615835, -2.9334834 ],
[ 1.0293068 , 1.1216859 , 0.35614944, -2.9340594 ],
[ 1.0291089 , 1.1214725 , 0.35615313, -2.9335957 ],
[ 1.0290648 , 1.1213634 , 0.35623825, -2.933438 ],
[ 1.0292054 , 1.1215473 , 0.35618582, -2.9337878 ],
[ 1.0281092 , 1.1202368 , 0.356357 , -2.931106 ],
[ 1.0292268 , 1.1216258 , 0.3561134 , -2.9338837 ],
[ 1.0290968 , 1.1214894 , 0.35614622, -2.9336073 ]],
dtype=float32), label_ids=array([0, 1, 2, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 2, 2, 0, 0, 2,
1, 1, 1, 3, 0, 2, 2, 2, 0, 1, 2, 2, 2, 0, 1, 0, 0, 1, 0, 1, 1, 2,
2, 1, 1, 1, 0, 1, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 1, 2, 2, 1, 0,
2, 1, 2, 0, 1, 1, 0, 2, 1, 2, 0, 1, 1, 1, 2, 0, 1, 0, 1, 2, 2, 1,
1, 0, 0, 1, 1, 1, 1, 2, 0, 2, 1, 1]), metrics={'eval_loss': 1.13915282029372, 'eval_accuracy': 0.42, 'eval_precision': 0.1764, 'eval_recall': 0.42})
```
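One way to quantify "predicts only one class" is to compare how much each class's logit varies across examples with the gaps between classes. Another is to compare `eval_loss` with the cross-entropy of a predictor that ignores its input. The sketch below uses four rows copied from the output above and the test supports (31, 42, 26, 1) from the report. To get exact figures, run it on the full `predictions` array of the `PredictionOutput`. If `eval_loss` sits close to the input-independent losses it prints, the model has probably learned little beyond the label frequencies. That is an interpretation, not something the sketch proves.

```python
import numpy as np

# Four logit rows copied from the PredictionOutput above (classes 0-3)
logits = np.array([
    [1.0289115, 1.1212691, 0.35617405, -2.9331322],
    [1.0295717, 1.1219745, 0.35612756, -2.934678],
    [1.027889, 1.1199136, 0.35638046, -2.9304833],
    [1.0272568, 1.1193739, 0.3562229, -2.929129],
])

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Spread of each class's logit across examples vs. the gaps between classes
print("per-class std across examples:", logits.std(axis=0))
probs = softmax(logits).mean(axis=0)
print("mean predicted probabilities:", probs)

# Test-set label frequencies, from the supports in the classification report
support = np.array([31, 42, 26, 1])
prior = support / support.sum()

# Lowest loss any input-independent predictor can reach on this test set: the label entropy
baseline_loss = -np.sum(prior * np.log(prior))
# Loss if every test example received the mean probabilities above
constant_loss = -np.sum(prior * np.log(probs))
print("prior baseline:", baseline_loss, "constant-prediction loss:", constant_loss)
```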
Can someone suggest what could have gone wrong? (I also tried batch sizes of 8 and 16, running on both CPU and GPU with the same datasets and parameters, changing the learning rate and number of epochs, and downsampling the training dataset to balance the classes.)
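For reference, the class balancing mentioned above can be done by random downsampling: every class is cut to the size of the rarest class. This is a minimal numpy sketch of that technique. `downsample_indices` is a hypothetical helper, and the toy labels are made up.

```python
import numpy as np

def downsample_indices(labels, seed=0):
    """Hypothetical helper: indices of a class-balanced subset (each class cut to the rarest class's count)."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(labels, return_counts=True)
    n_min = counts.min()
    # sample n_min distinct indices from each class
    picked = [rng.choice(np.flatnonzero(labels == c), size=n_min, replace=False) for c in classes]
    idx = np.concatenate(picked)
    rng.shuffle(idx)  # avoid a class-sorted order before batching
    return idx

labels = [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3]
idx = downsample_indices(labels)
print(np.bincount(np.asarray(labels)[idx]))  # [2 2 2 2]
```

The balanced set always has `num_classes * n_min` examples. If class 3 is as rare in the training data as in the test split (1 of 100), balancing this way leaves very few examples per class.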