Hello, I am trying to fine-tune a DistilBERT model on the `tweet_eval` irony subset. I have been scratching my head for about two days because, after training, the accuracy is not high enough for the model to predict the correct label reliably. Here is what I am doing in code:
```python
!pip install datasets

from datasets import load_dataset
from transformers import DistilBertTokenizer

tweet_dataset = load_dataset('tweet_eval', 'irony')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_function(example):
    # With batched=True, padding=True pads each map() batch to its longest sequence
    return tokenizer(example["text"], padding=True, truncation=True, max_length=512)

tokenized_tweet_dataset = tweet_dataset.map(tokenize_function, batched=True)

from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from sklearn.metrics import accuracy_score

def model_init():
    # Builds a fresh model for each run (hyperparameter_search relies on this)
    return DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

model_name = "distilbert-finetuned-tweet-eval"
training_args = TrainingArguments(
    output_dir=model_name,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_tweet_dataset['train'],
    eval_dataset=tokenized_tweet_dataset['validation'],
    # Accuracy of the argmax over the two logits
    compute_metrics=lambda p: {"accuracy": accuracy_score(p.label_ids,
                                                          np.argmax(p.predictions, axis=1))},
    tokenizer=tokenizer,
)
trainer.train()
```
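The lambda passed as `compute_metrics` only reports accuracy, which hides how the errors split between ironic and non-ironic tweets. Below is a minimal sketch of a replacement that also reports precision, recall and F1 for the ironic class. It assumes label 1 means "irony" (the label names can be checked with `tweet_dataset['train'].features['label'].names`), and it is written in plain Python so it needs no extra imports; iterating over the numpy arrays that `Trainer` passes in works the same way.

```python
def compute_metrics(eval_pred):
    """Accuracy plus precision/recall/F1 for the ironic class (assumed to be label 1)."""
    logits, labels = eval_pred  # unpacks to (predictions, label_ids)
    # Argmax over the logits of each example; works on lists or numpy arrays
    preds = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
    labels = [int(y) for y in labels]

    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    correct = sum(1 for p, y in zip(preds, labels) if p == y)

    # Guard against division by zero when a class is never predicted or never present
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": correct / len(labels),
        "precision_irony": precision,
        "recall_irony": recall,
        "f1_irony": f1,
    }
```

It would be passed as `compute_metrics=compute_metrics` in the `Trainer` call; a low `recall_irony` next to a decent accuracy would suggest the model is leaning toward the non-ironic label.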
This is the structure of the tokenized dataset:

```
DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2862
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 784
    })
    validation: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 955
    })
})
```
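The printout gives the size of each split but not its class balance, which is the reference point for any accuracy figure. Below is a small sketch that computes the label distribution of a split and the accuracy of a trivial model that always predicts the most frequent label. The helper names are hypothetical and not part of `datasets`.

```python
from collections import Counter


def label_distribution(labels):
    """Return {label: fraction of examples} for a sequence of integer labels."""
    counts = Counter(int(y) for y in labels)
    total = sum(counts.values())
    return {label: n / total for label, n in sorted(counts.items())}


def majority_baseline_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    return max(label_distribution(labels).values())
```

Applied to the real data, for example as `majority_baseline_accuracy(tweet_dataset['validation']['label'])`, this shows how far above the trivial baseline the fine-tuned model actually is. It is worth running on each split, since `train`, `validation` and `test` need not share the same balance.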
The accuracy I am getting is around 0.68-0.70, so the model often fails to detect irony correctly. I ran a hyperparameter search to find better values for the `TrainingArguments`, but the result doesn't improve.
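The post doesn't show how the hyperparameter search was run. For reference, here is a sketch of a search space for `Trainer.hyperparameter_search` with the Optuna backend. The ranges are illustrative guesses, not tuned values. Each key must match a `TrainingArguments` field name, and the `trial` object is the one Optuna passes in.

```python
def optuna_hp_space(trial):
    """Search space for Trainer.hyperparameter_search(backend="optuna").

    Keys must match TrainingArguments field names; the ranges are illustrative.
    """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [8, 16, 32]
        ),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 5),
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.1),
    }
```

It would be used as `trainer.hyperparameter_search(direction="maximize", backend="optuna", hp_space=optuna_hp_space, n_trials=10)`. This works with the `Trainer` above because it is built with `model_init`. If `compute_metrics` returns several metrics, the default objective sums them. Passing `compute_objective=lambda metrics: metrics["eval_accuracy"]` makes the optimized metric explicit.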
Please help me understand why this is happening and what could fix it.