MTL model to find entity names and correct them

I want to train a multi-task learning (MTL) model that combines NER with a seq2seq architecture to identify entity names and correct typos within them. The NER component is word-based, while the seq2seq component operates at the character level. The NER part performs well, but the seq2seq typo-correction part performs poorly. Can anyone help me understand what is wrong with the model or suggest ways to improve its performance?

# Toy training examples: (sentence containing a possibly misspelled entity,
# corrected entity string the decoder must produce, entity type).
# NOTE(review): the first two sentences contain the same surface token "OBAD"
# but expect different corrections ("oba" vs "obada"); the decoder can only
# tell them apart from the surrounding sentence context.
toy_data = [
    {"sentence": sentence, "correction": correction, "entity_type": entity_type}
    for sentence, correction, entity_type in (
        ("The superstore OBAD sold 1000 new deserts last month", "oba", "store"),
        (
            "The green organization of OBAD hired new volunteers to help them for clean forests",
            "obada",
            "organization",
        ),
        ("The tech giant applee inc unveiled its latest gadget", "apple", "company"),
    )
]


import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader



# Building vocabularies for words and characters
def build_word_vocab(sentences):
    vocab = set()
    for s in sentences:
        for w in s.lower().split():
            vocab.add(w)
    vocab.add("<pad>")
    vocab.add("<unk>")
    vocab.add("<start>")
    vocab.add("<end>")
    return sorted(list(vocab))


def build_char_vocab(texts):
    """Return the sorted character vocabulary for ``texts``.

    Every individual character of every text is included (case preserved),
    plus the special tokens ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``
    (the latter intended for characters outside the vocabulary).
    """
    characters = {ch for text in texts for ch in text}
    return sorted(characters | {"<pad>", "<start>", "<end>", "<unk>"})

# Collect all (lower-cased) sentences and the correction targets.
all_sentences = [entry["sentence"].lower() for entry in toy_data]
all_corrections = [entry["correction"] for entry in toy_data]

# Word vocabulary feeds the encoder; the character vocabulary is built from the
# correction strings only, since those are what the decoder emits.
word_vocab = build_word_vocab(all_sentences)
char_vocab = build_char_vocab(all_corrections)

# Forward (token -> id) and reverse (id -> token) lookup tables.
word2idx = {token: position for position, token in enumerate(word_vocab)}
idx2word = dict(enumerate(word_vocab))
char2idx = {symbol: position for position, symbol in enumerate(char_vocab)}
idx2char = dict(enumerate(char_vocab))

word_vocab_size = len(word2idx)
char_vocab_size = len(char2idx)

# Special token ids used by the encoders, the model and the training loop.
WORD_PAD_IDX = word2idx["<pad>"]
CHAR_PAD_IDX = char2idx["<pad>"]
SOS_token = char2idx["<start>"]
EOS_token = char2idx["<end>"]

def generate_ner_labels(sentence, entity_tokens=frozenset({"obad"})):
    """Split ``sentence`` on whitespace and tag every token for NER.

    Tokens that appear in ``entity_tokens`` get label 1 (beginning of a
    one-token entity); every other token gets label 0 (outside).  Matching is
    exact and case-sensitive, so callers should lower-case the sentence first
    (the dataset in this file does).

    NOTE(review): the default only tags "obad", so the misspelled entity
    "applee" in the toy data is labelled 0.  Pass e.g.
    ``entity_tokens={"obad", "applee"}`` to tag it.

    Args:
        sentence: the sentence to tokenize.
        entity_tokens: collection of tokens to tag as entities.

    Returns:
        (tokens, labels): the whitespace tokens and one int label per token.
    """
    tokens = sentence.split()
    labels = [1 if token in entity_tokens else 0 for token in tokens]
    return tokens, labels

# Longest sentence (counted in words) and longest correction (counted in
# characters); the dataset pads / truncates its sequences to these lengths.
max_word_len = max(len(sentence.split()) for sentence in all_sentences)
max_char_len = max(map(len, all_corrections))

# Functions to encode and pad word sequences and label sequences.
def encode_words(sentence, word2idx, max_len):
    """Encode the whitespace tokens of ``sentence`` as exactly ``max_len`` word ids.

    Tokens missing from ``word2idx`` map to the ``"<unk>"`` id; the sequence
    is truncated to ``max_len`` or right-padded with WORD_PAD_IDX.
    """
    encoded = [word2idx.get(token, word2idx["<unk>"]) for token in sentence.split()]
    encoded = encoded[:max_len]
    return encoded + [WORD_PAD_IDX] * (max_len - len(encoded))

def encode_labels(labels, max_len, pad_idx=WORD_PAD_IDX):
    """Truncate or right-pad ``labels`` to exactly ``max_len`` entries.

    NOTE(review): the default pad value is WORD_PAD_IDX, a *word-vocabulary*
    index, not a label value.  It is only safe while it never equals a real
    NER class id (0, 1 or 2) and the NER loss ignores that same value.
    Otherwise genuine labels get ignored or padding gets trained on.
    """
    if len(labels) >= max_len:
        return labels[:max_len]
    return labels + [pad_idx] * (max_len - len(labels))

def encode_chars(text, char2idx, max_len, pad_idx=None):
    """Encode ``text`` as exactly ``max_len`` character ids.

    Characters missing from ``char2idx`` map to its ``"<unk>"`` id, which the
    char vocabulary provides for exactly this purpose.  They were previously
    dropped silently.  If the mapping has no ``"<unk>"`` entry they are still
    skipped.  No ``<start>``/``<end>`` markers are added.

    Args:
        text: the string to encode.
        char2idx: character -> id mapping.
        max_len: length of the returned list (truncate or right-pad).
        pad_idx: padding id; defaults to the module-level CHAR_PAD_IDX.

    Returns:
        A list of ``max_len`` ints.
    """
    if pad_idx is None:
        pad_idx = CHAR_PAD_IDX
    unk_idx = char2idx.get("<unk>")
    ids = []
    for ch in text:
        idx = char2idx.get(ch, unk_idx)
        if idx is not None:
            ids.append(idx)
    ids = ids[:max_len]
    return ids + [pad_idx] * (max_len - len(ids))

# Create a dataset class that returns:
#  (1) word–level tensor for NER,
#  (2) word–level NER labels,
#  (3) char–level tensor for correction target.
class MultiTaskDataset(Dataset):
    """Map-style dataset yielding three long tensors per example.

    * word ids of the lower-cased sentence, length ``max_word_len``;
    * per-word NER label ids, length ``max_word_len``;
    * char ids of the ``correction`` string, length ``max_char_len``.
    """

    def __init__(self, data, word2idx, char2idx, max_word_len, max_char_len):
        self.data = data
        self.word2idx = word2idx
        self.char2idx = char2idx
        self.max_word_len = max_word_len
        self.max_char_len = max_char_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        sentence = example["sentence"].lower()
        _, labels = generate_ner_labels(sentence)
        encoded = (
            encode_words(sentence, self.word2idx, self.max_word_len),
            encode_labels(labels, self.max_word_len),
            encode_chars(example["correction"], self.char2idx, self.max_char_len),
        )
        return tuple(torch.tensor(seq, dtype=torch.long) for seq in encoded)

# Wrap the toy examples; the loader yields shuffled batches of up to 4 examples.
dataset = MultiTaskDataset(
    toy_data,
    word2idx,
    char2idx,
    max_word_len=max_word_len,
    max_char_len=max_char_len,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Inputs are (batch, seq_len, d_model) tensors.  An optional ``mask`` is
    broadcast against the (batch, heads, q_len, k_len) score tensor; scores
    where it equals 0/False are replaced by -1e9 before the softmax.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """softmax(Q K^T / sqrt(d_k)) V over the last two dimensions."""
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        return weights @ V

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        batch_size, seq_length, _ = x.shape
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """(batch, heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.shape
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        q_heads, k_heads, v_heads = (
            self.split_heads(projection(tensor))
            for projection, tensor in ((self.W_q, Q), (self.W_k, K), (self.W_v, V))
        )
        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        return self.W_o(self.combine_heads(context))

class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP (d_model -> d_ff -> d_model with ReLU) applied at every position."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    Even feature columns hold sin(pos * w_i) and odd columns cos(pos * w_i),
    with w_i = 10000^(-2i / d_model).  The table is precomputed for
    ``max_len`` positions and stored as the buffer ``pe`` of shape
    (1, max_len, d_model).
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column;
        # slicing keeps the shapes compatible (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is expected to be (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` the table was built for
                (previously this surfaced as an opaque broadcasting error).
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {table_len}"
            )
        return x + self.pe[:, :seq_len]

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Self-attention, then a position-wise feed-forward network.  Each sub-layer
    output goes through dropout, is added to its input, and is LayerNorm-ed.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """norm(residual + dropout(sublayer_out))."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, mask):
        x = self._add_and_norm(x, self.self_attn(x, x, x, mask), self.norm1)
        return self._add_and_norm(x, self.ffn(x), self.norm2)

class DecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Masked self-attention over the target, cross-attention over the encoder
    output, then a position-wise feed-forward network.  Each sub-layer output
    goes through dropout, is added to its input, and is LayerNorm-ed.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """norm(residual + dropout(sublayer_out))."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask, tgt_mask):
        x = self._add_and_norm(x, self.self_attn(x, x, x, tgt_mask), self.norm1)
        x = self._add_and_norm(x, self.cross_attn(x, enc_output, enc_output, src_mask), self.norm2)
        return self._add_and_norm(x, self.ffn(x), self.norm3)

class MultiTaskTransformer(nn.Module):
    """Shared word-level encoder with an NER head and a character-level correction decoder.

    * Encoder: word embeddings + sinusoidal positions + ``num_layers`` EncoderLayers.
    * NER head: per-word linear classifier over 3 classes; index 0 is
      "outside" and indices 1-2 are the entity classes.
    * Decoder: char embeddings + positions + ``num_layers`` DecoderLayers.
      The decoder cross-attends to the encoder output extended with one extra
      "entity summary" vector.  That vector is the average of the encoder
      states, weighted by each non-padding word's predicted entity probability.
    """

    def __init__(self, word_vocab_size, char_vocab_size, d_model, num_heads, num_layers, d_ff,
                 max_word_len, max_char_len, dropout, word_pad_idx, char_pad_idx):
        super(MultiTaskTransformer, self).__init__()
        self.word_pad_idx = word_pad_idx
        self.char_pad_idx = char_pad_idx
        self.max_word_len = max_word_len
        # Also the length of the decoder positional-encoding table, i.e. the
        # longest decoder input this model accepts.
        self.max_char_len = max_char_len

        # Encoder: word embeddings + positional encoding.
        self.word_embedding = nn.Embedding(word_vocab_size, d_model, padding_idx=word_pad_idx)
        self.pos_enc = PositionalEncoding(d_model, max_word_len)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        # NER head: per-word classification into 3 classes (O, B-ENTITY, I-ENTITY)
        self.ner_classifier = nn.Linear(d_model, 3)

        # Decoder: character embeddings + positional encoding.
        self.char_embedding = nn.Embedding(char_vocab_size, d_model, padding_idx=char_pad_idx)
        self.pos_enc_dec = PositionalEncoding(d_model, max_char_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, char_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _target_mask(self, tgt_ids):
        """Padding + causal (no-peek) mask for decoder self-attention, (batch, 1, T, T)."""
        pad_mask = (tgt_ids != self.char_pad_idx).unsqueeze(1).unsqueeze(3)
        seq_len = tgt_ids.size(1)
        nopeak_mask = torch.triu(torch.ones((1, seq_len, seq_len), device=tgt_ids.device), diagonal=1).bool()
        return pad_mask & ~nopeak_mask

    def generate_masks(self, word_ids, tgt_ids):
        """Return (encoder padding mask (batch, 1, 1, S), decoder mask (batch, 1, T, T))."""
        enc_mask = (word_ids != self.word_pad_idx).unsqueeze(1).unsqueeze(2)
        return enc_mask, self._target_mask(tgt_ids)

    def _encode(self, word_ids):
        """Run the encoder and NER head.

        Returns:
            (ner_logits, memory, memory_mask).  ``memory`` is the encoder output
            with the entity-summary vector appended along the sequence axis, and
            ``memory_mask`` is the matching cross-attention mask.
        """
        enc_mask = (word_ids != self.word_pad_idx).unsqueeze(1).unsqueeze(2)
        enc_output = self.dropout(self.pos_enc(self.word_embedding(word_ids)))
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, enc_mask)

        ner_logits = self.ner_classifier(enc_output)  # (batch, word_seq_len, 3)
        ner_probs = torch.softmax(ner_logits, dim=-1)
        # Entity probability per word (classes 1 and 2).  Padding positions are
        # zeroed: their NER predictions are not supervised (padded labels are
        # ignored by the loss), so they must not leak into the summary.
        not_pad = (word_ids != self.word_pad_idx).unsqueeze(-1).to(ner_probs.dtype)
        entity_prob = ner_probs[:, :, 1:3].sum(dim=-1, keepdim=True) * not_pad  # (batch, S, 1)
        entity_sum = (enc_output * entity_prob).sum(dim=1, keepdim=True)
        denom = entity_prob.sum(dim=1, keepdim=True) + 1e-8
        entity_summary = entity_sum / denom  # (batch, 1, d_model)

        memory = torch.cat([enc_output, entity_summary], dim=1)
        # The appended summary slot is always attendable.
        summary_mask = torch.ones((enc_mask.size(0), 1, 1, 1), dtype=torch.bool, device=enc_mask.device)
        memory_mask = torch.cat([enc_mask, summary_mask], dim=-1)
        return ner_logits, memory, memory_mask

    def _decode(self, tgt_ids, memory, memory_mask):
        """Return char logits (batch, tgt_len, char_vocab_size) for decoder input ``tgt_ids``."""
        tgt_mask = self._target_mask(tgt_ids)
        dec_output = self.dropout(self.pos_enc_dec(self.char_embedding(tgt_ids)))
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, memory, memory_mask, tgt_mask)
        return self.fc_out(dec_output)

    def forward(self, word_ids, tgt_ids=None):
        """Compute NER logits and, if ``tgt_ids`` is given, teacher-forced correction logits.

        Args:
            word_ids: (batch, word_seq_len) word ids.
            tgt_ids: optional (batch, tgt_len) decoder input char ids;
                tgt_len must not exceed ``max_char_len``.

        Returns:
            (ner_logits, correction_logits); correction_logits is None when
            ``tgt_ids`` is None.
        """
        ner_logits, memory, memory_mask = self._encode(word_ids)
        correction_logits = None
        if tgt_ids is not None:
            correction_logits = self._decode(tgt_ids, memory, memory_mask)
        return ner_logits, correction_logits

    @torch.no_grad()
    def greedy_decode(self, word_ids, sos_idx, eos_idx, max_len=None):
        """Generate corrections autoregressively (greedy arg-max, no teacher forcing).

        Call ``model.eval()`` first so dropout is disabled.  Assumes the decoder
        was trained on inputs starting with ``sos_idx`` and on targets ending
        with ``eos_idx``.

        Args:
            word_ids: (batch, word_seq_len) word ids.
            sos_idx: char id of the start marker fed as the first decoder input.
            eos_idx: char id of the end marker that stops a row.
            max_len: maximum number of generated tokens; capped at
                ``self.max_char_len``.

        Returns:
            (ner_logits, generated).  ``generated`` is (batch, <= max_len) char
            ids without the leading start marker; positions after a row's end
            marker hold the char pad id.
        """
        limit = self.max_char_len if max_len is None else min(max_len, self.max_char_len)
        ner_logits, memory, memory_mask = self._encode(word_ids)
        batch_size = word_ids.size(0)
        generated = torch.full((batch_size, 1), sos_idx, dtype=torch.long, device=word_ids.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=word_ids.device)
        for _ in range(limit):
            logits = self._decode(generated, memory, memory_mask)
            next_tok = logits[:, -1].argmax(dim=-1).masked_fill(finished, self.char_pad_idx)
            generated = torch.cat([generated, next_tok.unsqueeze(1)], dim=1)
            finished = finished | (next_tok == eos_idx)
            if bool(finished.all()):
                break
        return ner_logits, generated[:, 1:]

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters.
d_model = 128
num_heads = 8
num_layers = 4
d_ff = 512
dropout = 0.1

# Label value the NER loss ignores.  It is deliberately NOT WORD_PAD_IDX:
# that is a word-vocabulary index and can coincide with a real NER class id.
NER_IGNORE_IDX = -100

# The dataset's char targets (encode_chars) carry no <start>/<end> markers, so
# the decoder never saw a start token and never learned when to stop.
# add_sos_eos adds them on the fly, which makes decoder sequences up to
# max_char_len + 1 long.  The decoder positional encoding is sized accordingly.
dec_max_len = max_char_len + 2


def add_sos_eos(tgt_ids):
    """Build teacher-forcing (decoder_input, decoder_target) from padded char ids.

    Each right-padded row ``c1 .. cn <pad> ..`` becomes
        input : <start> c1 .. cn ...
        target: c1 .. cn <end> <pad> ..
    Both outputs are (batch, tgt_ids.size(1) + 1).  Target positions after <end>
    are <pad> and are therefore ignored by the correction loss.
    """
    batch_size, seq_len = tgt_ids.shape
    lengths = (tgt_ids != CHAR_PAD_IDX).sum(dim=1)
    full = torch.full((batch_size, seq_len + 2), CHAR_PAD_IDX, dtype=tgt_ids.dtype, device=tgt_ids.device)
    full[:, 0] = SOS_token
    full[:, 1:seq_len + 1] = tgt_ids
    full[torch.arange(batch_size, device=tgt_ids.device), lengths + 1] = EOS_token
    return full[:, :-1], full[:, 1:]


def ids_to_text(ids):
    """Convert a 1-D tensor of char ids to text, stopping at <end>, skipping <pad>/<start>."""
    chars = []
    for idx in ids.tolist():
        if idx == EOS_token:
            break
        if idx not in (CHAR_PAD_IDX, SOS_token):
            chars.append(idx2char[idx])
    return "".join(chars)


model = MultiTaskTransformer(
    word_vocab_size=word_vocab_size,
    char_vocab_size=char_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_word_len=max_word_len,
    max_char_len=dec_max_len,
    dropout=dropout,
    word_pad_idx=WORD_PAD_IDX,
    char_pad_idx=CHAR_PAD_IDX
).to(device)


@torch.no_grad()
def greedy_correct(word_ids, max_steps=max_char_len + 1):
    """Greedily decode corrections with ``model`` (no teacher forcing).

    Every row starts from <start>; the arg-max char is appended until every row
    contains <end> or ``max_steps`` chars were produced.  The default leaves
    room for the longest training correction plus <end> and keeps decoder
    inputs within ``dec_max_len``.  Returns generated ids without <start>.
    """
    batch_size = word_ids.size(0)
    generated = torch.full((batch_size, 1), SOS_token, dtype=torch.long, device=word_ids.device)
    for _ in range(max_steps):
        _, logits = model(word_ids, generated)
        next_char = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_char], dim=1)
        if (generated == EOS_token).any(dim=1).all():
            break
    return generated[:, 1:]


# Loss functions
criterion_ner = nn.CrossEntropyLoss(ignore_index=NER_IGNORE_IDX)
criterion_corr = nn.CrossEntropyLoss(ignore_index=CHAR_PAD_IDX)

# Dynamic loss balancing parameters (learned log-variances, uncertainty weighting)
log_var_ner = torch.zeros(1, requires_grad=True, device=device)
log_var_corr = torch.zeros(1, requires_grad=True, device=device)

# Optimizer
optimizer = optim.Adam(list(model.parameters()) + [log_var_ner, log_var_corr], lr=1e-4)

# Curriculum learning: train correction first, then NER
total_epochs = 600
curriculum_epochs = 200  # For the first 200 epochs, only the correction loss is used.
max_grad_norm = 1.0

# Training Loop
model.train()
for epoch in range(total_epochs):
    epoch_loss = 0.0
    for word_ids, ner_labels, tgt_ids in dataloader:
        word_ids = word_ids.to(device)
        ner_labels = ner_labels.to(device)
        tgt_ids = tgt_ids.to(device)

        dec_input, dec_target = add_sos_eos(tgt_ids)

        optimizer.zero_grad()

        # Forward pass first: both losses need its outputs.
        ner_logits, corr_logits = model(word_ids, dec_input)

        # NER loss: ignore padded word positions whatever value the dataset
        # used to pad the label sequence.
        ner_targets = ner_labels.masked_fill(word_ids == WORD_PAD_IDX, NER_IGNORE_IDX)
        loss_ner = criterion_ner(ner_logits.reshape(-1, 3), ner_targets.reshape(-1))

        # Correction loss: predict the next char (ending with <end>) at every step.
        loss_corr = criterion_corr(corr_logits.reshape(-1, char_vocab_size), dec_target.reshape(-1))

        # Dynamic loss balancing (uncertainty weighting)
        if epoch < curriculum_epochs:
            total_loss = torch.exp(-log_var_corr) * loss_corr + log_var_corr
        else:
            total_loss = (torch.exp(-log_var_corr) * loss_corr + log_var_corr +
                          torch.exp(-log_var_ner) * loss_ner + log_var_ner)

        total_loss.backward()
        clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        epoch_loss += total_loss.item()

    avg_loss = epoch_loss / len(dataloader)
    phase = "Correction-only" if epoch < curriculum_epochs else "Joint training"
    print(f"Epoch {epoch+1}/{total_epochs} [{phase}], Loss: {avg_loss:.4f} | "
          f"w_ner: {log_var_ner.item():.4f}, w_corr: {log_var_corr.item():.4f}")

# Autoregressive evaluation: teacher-forced loss overstates seq2seq quality, so
# check what the corrector actually generates on its own.
model.eval()
for word_ids, _, tgt_ids in dataloader:
    predicted = greedy_correct(word_ids.to(device))
    for gold_row, pred_row in zip(tgt_ids, predicted):
        print(f"target: {ids_to_text(gold_row)!r} | predicted: {ids_to_text(pred_row)!r}")
1 Like