Loss not decreasing on SST2

Hi,

The following script fine-tunes a BertForSequenceClassification model on SST2.

The script is adapted from this Colab, which presents an example of fine-tuning BertForQuestionAnswering on the SQuAD dataset. In that Colab, the loss decreases as expected. When I adapt the script to SST2, the loss fails to decrease as it should. I attach the adapted script below and would appreciate anyone pointing out what I am missing.

import torch
from datasets import load_dataset
from transformers import BertForSequenceClassification
from transformers import BertTokenizerFast
# Load our training dataset and tokenizer
dataset = load_dataset("glue", 'sst2')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
del dataset["test"] # let's remove it in this demo

# Tokenize our training dataset
def convert_to_features(example_batch):
    encodings = tokenizer(example_batch["sentence"])
    encodings.update({"labels": example_batch["label"]})
    return encodings

encoded_dataset = dataset.map(convert_to_features, batched=True)
# Format our dataset to output torch.Tensor so we can train a PyTorch model
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']
encoded_dataset.set_format(type='torch', columns=columns)

# Instantiate a PyTorch Dataloader around our dataset
# Let's do dynamic batching (pad on the fly with our own collate_fn)
def collate_fn(examples):
    return tokenizer.pad(examples, return_tensors='pt')

dataloader = torch.utils.data.DataLoader(encoded_dataset['train'], collate_fn=collate_fn, batch_size=8)
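# (Note: DataLoader defaults to shuffle=False; shuffle=True would be the
# usual choice for fine-tuning. I kept the default here, but am noting it
# in case it matters.)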
# Now let's train our model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Let's load a pretrained Bert model and a simple optimizer
model = BertForSequenceClassification.from_pretrained('bert-base-cased', return_dict=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train().to(device)
for i, batch in enumerate(dataloader):
    batch = batch.to(device)  # BatchEncoding.to moves every tensor to the device
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    model.zero_grad()
    print(f'Step {i} - loss: {loss.item():.3f}')
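
In case it helps with debugging, here is a minimal sanity check (a sketch, run separately from the loop above) that the collate_fn pads the inputs and carries the labels through tokenizer.pad:

batch = next(iter(dataloader))
print(batch.keys())              # expect: input_ids, token_type_ids, attention_mask, labels
print(batch['input_ids'].shape)  # (8, longest sequence in this batch)
print(batch['labels'])           # tensor of 0/1 class ids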


In case it's needed:

  • datasets == 1.0.2
  • transformers == 3.2.0
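
For completeness, here is a variant of the update step that I would expect to behave the same (a sketch reusing the model, optimizer, and dataloader defined above): zeroing gradients through the optimizer before the backward pass should be equivalent to calling model.zero_grad() after the step, since the optimizer was built from model.parameters().

for i, batch in enumerate(dataloader):
    batch = batch.to(device)
    optimizer.zero_grad()  # same effect as model.zero_grad() here
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    print(f'Step {i} - loss: {loss.item():.3f}')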