BERT token classification / regression question

Hi all, I’m fine-tuning BERT for emotion classification / regression at the token level, see the code below. Specifically, I want to predict intensity values for a list of specific emotions for each token, e.g. "This party sucks!" -> This (sad: 0.0, angry: 0.0, happy: 0.0, ...), party (sad: 0.5, angry: 0.6, happy: 0.0, ...), sucks (sad: 0.5, angry: 0.6, happy: 0.0, ...). My question is: can the model accurately predict emotions at the token level when it hasn’t “read” the whole sentence yet? E.g. when the model predicts the emotions for “party”, it hasn’t read “sucks” yet and doesn’t know whether the sentence is positive or negative, right?
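
For context, the per-word intensities are aligned to BERT’s sub-word tokens roughly like this before training (a sketch assuming a fast tokenizer; only three of the twelve emotions are shown and the values are just the example above):

from transformers import BertTokenizerFast
import torch

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Word-level intensities for "This party sucks!" (order: sad, angry, happy)
word_labels = [
    [0.0, 0.0, 0.0],  # This
    [0.5, 0.6, 0.0],  # party
    [0.5, 0.6, 0.0],  # sucks
    [0.0, 0.0, 0.0],  # !
]

enc = tokenizer("This party sucks!", return_tensors="pt")
num_emotions = len(word_labels[0])

# Copy each word's intensities to all of its sub-word tokens; special tokens get zeros
token_labels = []
for word_id in enc.word_ids(batch_index=0):
    if word_id is None:
        token_labels.append([0.0] * num_emotions)
    else:
        token_labels.append(word_labels[word_id])

labels = torch.tensor([token_labels])  # shape: (1, seq_len, num_emotions)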

import torch.nn as nn
from torch.nn import MSELoss
from transformers import BertModel, BertPreTrainedModel

class BertForTokenEmotionIntensity(BertPreTrainedModel):
    def __init__(self, config, num_emotions=12):
        super().__init__(config)
        self.num_emotions = num_emotions
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_emotions)  # Per-token output layer: one intensity per emotion
        self.init_weights()
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get BERT's outputs
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # Shape: (batch_size, seq_len, hidden_size)
        
        # Apply dropout and classification layer
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # Shape: (batch_size, seq_len, num_emotions)
        loss = None
        
        if labels is not None:
            # Calculate mean squared error loss for intensity prediction
            loss_fn = MSELoss()
            active_loss = attention_mask.view(-1) == 1  # Drop padding tokens (note: [CLS]/[SEP] still contribute to the loss)
            active_logits = logits.view(-1, self.num_emotions)[active_loss]
            active_labels = labels.view(-1, self.num_emotions)[active_loss]
            loss = loss_fn(active_logits, active_labels)

        return {"loss": loss, "logits": logits}

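The model and optimizer for the loop below are set up roughly like this (the checkpoint name and learning rate are just placeholders):

from torch.optim import AdamW

# from_pretrained forwards unused init kwargs such as num_emotions to __init__
model = BertForTokenEmotionIntensity.from_pretrained("bert-base-uncased", num_emotions=12)
optimizer = AdamW(model.parameters(), lr=2e-5)
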
model.train()
for epoch in range(100):
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs['loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
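
After training, the per-token predictions for the example sentence can be read off like this (a rough sketch; tokenizer as in the alignment snippet above):

import torch

model.eval()
enc = tokenizer("This party sucks!", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])["logits"]

# logits has shape (1, seq_len, num_emotions): one intensity vector per sub-word token,
# including [CLS] and [SEP]
for token, intensities in zip(tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist()), logits[0]):
    print(token, intensities.tolist())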