Hello,
I’m trying to fine-tune a custom model built on top of BertModel for a sequence classification task, but I can’t get the Trainer to log the validation loss: the training loss is logged, but the validation loss never shows up.
Here is the code I’m using:
```python
from datasets import load_dataset
import torch.nn as nn
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        criterion = nn.CrossEntropyLoss()
        # CustomModel.forward swallows the extra 'labels' key via **kwargs
        # and returns raw logits, so the loss is computed here instead
        outputs = model(**inputs)
        loss = criterion(outputs, inputs['labels'])
        return (loss, outputs) if return_outputs else loss


class CustomModel(nn.Module):
    def __init__(self, d_out):
        super().__init__()
        self.encoder = AutoModel.from_pretrained('bert-base-uncased')
        d_in = self.encoder.config.hidden_size
        self.fc1 = nn.Linear(d_in, d_in)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(d_in, d_out)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Forward the attention mask so the encoder ignores padding tokens
        x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        x = self.fc1(x)
        x = self.activation(x)
        logits = self.fc2(x)
        return logits


def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)


raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    'test-trainer',
    report_to='none',
    save_strategy='no',
    evaluation_strategy='steps',
    eval_steps=50,
)

# model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model = CustomModel(2)

trainer = CustomTrainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
```
When I swap in a plain `AutoModelForSequenceClassification` (the commented-out line above) with the stock `Trainer`, the validation loss is logged fine. Is there something I’m missing?
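For reference, the working baseline is essentially the same script with the custom pieces swapped out; only the model and the trainer class change, and the data pipeline and `TrainingArguments` are reused as-is:

```python
# Baseline that does log eval_loss: a stock classification head (which computes
# and returns the loss itself) plus the stock Trainer, reusing checkpoint,
# training_args, tokenized_datasets, and data_collator from the script above.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()  # here eval_loss shows up in the log every 50 steps
```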