Hi all, I’m fine-tuning BERT for token-level emotion classification/regression, see code below. Specifically, I want to predict an intensity value for each emotion in a fixed list, for every token, e.g. "This party sucks!" -> This (sad: 0.0, angry: 0.0, happy: 0.0, ...), party (sad: 0.5, angry: 0.6, happy: 0.0, ...), sucks (sad: 0.5, angry: 0.6, happy: 0.0, ...).

My question is: can the model accurately predict the emotion at the token level when it hasn’t “read” the whole sentence yet? E.g. when the model predicts the emotion for “party”, it hasn’t read “sucks” yet and doesn’t know whether the sentence is positive or negative, right?
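For context, this is roughly how I expand the per-word intensity labels to BERT's word-piece tokens (a minimal sketch; the three-emotion subset, the tokenizer choice, and the variable names are just placeholders for my actual setup):

from transformers import BertTokenizerFast

EMOTIONS = ["sad", "angry", "happy"]  # placeholder: my real label set has 12 emotions
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

words = ["This", "party", "sucks!"]
word_labels = [[0.0, 0.0, 0.0], [0.5, 0.6, 0.0], [0.5, 0.6, 0.0]]  # per-word intensities

enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
labels = []
for word_id in enc.word_ids(batch_index=0):
    if word_id is None:                      # [CLS] / [SEP] / padding
        labels.append([0.0] * len(EMOTIONS))
    else:                                    # each word piece inherits its word's intensities
        labels.append(word_labels[word_id])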
import torch.nn as nn
from torch.nn import MSELoss
from transformers import BertModel, BertPreTrainedModel


class BertForTokenEmotionIntensity(BertPreTrainedModel):
    def __init__(self, config, num_emotions=12):
        super().__init__(config)
        self.num_emotions = num_emotions
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Regression head: one intensity value per emotion for every token
        self.classifier = nn.Linear(config.hidden_size, self.num_emotions)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Contextualized token representations from BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Apply dropout and the regression head
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (batch_size, seq_len, num_emotions)

        loss = None
        if labels is not None:
            # Mean squared error over the real (non-padding) tokens only
            loss_fn = MSELoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_emotions)[active_loss]
            active_labels = labels.view(-1, self.num_emotions)[active_loss]
            loss = loss_fn(active_logits, active_labels)

        return {"loss": loss, "logits": logits}
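The training loop below assumes `model`, `optimizer`, and `dataloader` are set up elsewhere, roughly like this (a sketch; `train_dataset`, the batch size, and the learning rate are just what I happen to use):

from torch.optim import AdamW
from torch.utils.data import DataLoader

model = BertForTokenEmotionIntensity.from_pretrained("bert-base-uncased", num_emotions=12)
optimizer = AdamW(model.parameters(), lr=2e-5)
# train_dataset is my own Dataset returning input_ids / attention_mask / labels per example
dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)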
model.train()
for epoch in range(100):
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs['loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
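And for completeness, this is roughly how I run inference afterwards (using the same tokenizer as in the label sketch above; note the logits come from a plain linear layer, so they aren't bounded to [0, 1]):

import torch

model.eval()
enc = tokenizer("This party sucks!", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
for token, intensities in zip(tokens, out["logits"][0]):
    print(token, intensities.tolist())  # one intensity vector per word-piece token, incl. [CLS]/[SEP]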