Hi,
The following script fine-tunes a BertForSequenceClassification model on SST2.
The script is adapted from this Colab, which presents an example of fine-tuning BertForQuestionAnswering on the SQuAD dataset. In that Colab, the loss decreases as expected. When I adapt it to SST2, the loss fails to decrease as it should. I attach the adapted script below and would appreciate it if anyone could point out what I am missing.
import torch
from datasets import load_dataset
from transformers import BertForSequenceClassification
from transformers import BertTokenizerFast
# Load our training dataset and tokenizer
dataset = load_dataset("glue", 'sst2')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
del dataset["test"] # let's remove it in this demo
# Tokenize our training dataset
def convert_to_features(example_batch):
    encodings = tokenizer(example_batch["sentence"])
    encodings.update({"labels": example_batch["label"]})
    return encodings
encoded_dataset = dataset.map(convert_to_features, batched=True)
# Format our dataset to output torch.Tensor so we can train a PyTorch model
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']
encoded_dataset.set_format(type='torch', columns=columns)
# Instantiate a PyTorch Dataloader around our dataset
# Let's do dynamic batching (pad on the fly with our own collate_fn)
def collate_fn(examples):
    return tokenizer.pad(examples, return_tensors='pt')
dataloader = torch.utils.data.DataLoader(encoded_dataset['train'], collate_fn=collate_fn, batch_size=8)
# Now let's train our model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Let's load a pretrained Bert model and a simple optimizer
model = BertForSequenceClassification.from_pretrained('bert-base-cased', return_dict=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train().to(device)
for i, batch in enumerate(dataloader):
    batch.to(device)
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    model.zero_grad()
    print(f'Step {i} - loss: {loss:.3}')
In case it is needed, these are the library versions I am using:
- datasets == 1.0.2
- transformers == 3.2.0
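
Since the script prints the raw loss of each step with a batch size of 8, the values are probably noisy, and a smoothed curve might make it easier to judge whether the loss is really flat. Here is a minimal sketch of that idea, assuming a bias-corrected exponential moving average is an acceptable way to smooth. The helper `smooth_losses` and the smoothing factor `beta=0.9` are my own hypothetical choices and are not part of the Colab or of transformers, and the numbers in `example` are made up for illustration only.

```python
def smooth_losses(losses, beta=0.9):
    """Return bias-corrected exponential moving averages of a loss sequence.

    Hypothetical helper for illustration; beta is an assumed smoothing factor.
    """
    smoothed = []
    avg = 0.0
    for step, loss in enumerate(losses, start=1):
        # Blend the running average with the newest loss value.
        avg = beta * avg + (1.0 - beta) * loss
        # Bias correction so that early values are not pulled toward zero.
        smoothed.append(avg / (1.0 - beta ** step))
    return smoothed


# Made-up per-step losses, NOT real output from the script above.
example = [0.72, 0.65, 0.80, 0.61, 0.70, 0.55, 0.66, 0.50]
for step, value in enumerate(smooth_losses(example)):
    print(f"Step {step} - smoothed loss: {value:.3f}")
```

In the training loop, one could append `loss.item()` to a list at every step and print the last smoothed value next to the raw one.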