My model doesn't learn with my triplet loss

Hi there, I wasn’t quite sure whether to post this in the intermediate or the beginner channel, but I guess I’ve made my choice :face_in_clouds:

I’m working on a similarity task where, for each example, I have an anchor, a query, a positive, and a negative.
The goal is to train a model that gives a higher similarity score to the positive response than to the negative one.

Here’s what I’ve implemented so far:

* I use a T5 model to generate a response to the query based on the anchor.
* I then encode the generated response using BERT (taking the [CLS] token embedding).
* Finally, I compute a triplet loss (there is a small standalone sketch of this step right after the list) between:
  * the generated (anchor) embedding
  * the positive embedding
  * the negative embedding
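
To make the objective concrete, here is a minimal, self-contained sketch of that last step, with random tensors standing in for the real [CLS] embeddings (the sizes and the margin are just placeholders):

import torch
import torch.nn as nn

# Placeholder [CLS]-style embeddings for the generated response, the positive and the negative
batch_size, hidden_size = 8, 768
anchor_emb = torch.randn(batch_size, hidden_size, requires_grad=True)  # generated response
positive_emb = torch.randn(batch_size, hidden_size)
negative_emb = torch.randn(batch_size, hidden_size)

# Margin-based triplet loss: pushes d(anchor, positive) below d(anchor, negative) by at least `margin`
triplet_loss_fn = nn.TripletMarginLoss(margin=0.3, p=2)
loss = triplet_loss_fn(anchor_emb, positive_emb, negative_emb)
loss.backward()  # gradients only reach tensors that require grad (here: anchor_emb)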

The problem is… the model doesn’t seem to learn anything at all: the loss doesn’t decrease.

If anyone can help me understand what I might be doing wrong, or point me in the right direction, I’d really appreciate it :folded_hands:

Here is the model class:

import json

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, T5Tokenizer, T5ForConditionalGeneration

class CustomModel:
    def __init__(self, truncate=True, path='/content/dataset_big_patent_v3.json',
                 bert_model_name="bert-base-uncased", margin=0.3, device=None):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name).to(self.device)

        self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.T5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(self.device)

        self.triplet_loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=2)

        self.load_data(path)
        if truncate:
            self.data["anchor"] = self.data["anchor"].apply(self.truncate_text)

        self.data['inputs'] = 'question: ' + self.data['query'] + '\ncontext: ' + self.data['anchor']
        self.dataset = TripletTextDataset(self.data)

    def load_data(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = pd.DataFrame(data)

    def truncate_text(self, text):
        text_lower = text.lower()
        point = next((p for p in (
            text_lower.find('technical field'),
            text_lower.find('invention'),
            text_lower.find('disclosure'),
            text_lower.find('this application')
        ) if p != -1), None)
        return text[point:] if point is not None else text

    def get_embedding(self, text, requires_grad=False):
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        if requires_grad:
            outputs = self.bert(**tokens)
        else:
            with torch.no_grad():
                outputs = self.bert(**tokens)
        return outputs.last_hidden_state[:, 0, :].squeeze(0)

    def fit(self, optimizer, batch_size=8, epochs=10):
        self.T5.train()
        self.bert.eval()
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            total_loss = 0.0
            print(f"\nEpoch {epoch + 1}/{epochs}")

            for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
                input_texts = batch['anchor']
                tokenized = self.t5_tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
                generated_ids = self.T5.generate(
                    **tokenized,
                    num_beams=4,
                    max_length=256,
                    early_stopping=True
                )
                generated_texts = [self.t5_tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]

                # Get anchor embeddings with grad
                anchor_emb = torch.stack([self.get_embedding(text, requires_grad=True) for text in generated_texts])

                # Get positive and negative embeddings without grad
                positive_emb = torch.stack([self.get_embedding(p) for p in batch['positive']])
                negative_emb = torch.stack([self.get_embedding(n) for n in batch['negative']])

                optimizer.zero_grad()
                loss = self.triplet_loss_fn(anchor_emb, positive_emb, negative_emb)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            mean_loss = total_loss / len(dataloader)
            print(f"Mean triplet loss for epoch {epoch + 1}: {mean_loss:.4f}")

        return mean_loss

If learning is merely inefficient, the training parameters may be suspect, but if there is no change at all, there is probably a problem in the model code itself…


by HuggingChat:

Your model may not be learning during training due to several potential issues with the model setup and training process. Here are the possible reasons and solutions based on the provided code and reasoning:

  1. BERT Is Treated as Frozen: In your fit method you call self.bert.eval() and treat BERT as fixed. But BERT is the only part of the pipeline the triplet loss can actually update: T5.generate() returns discrete token ids, so no gradient flows back into T5 through the generated text. If BERT’s parameters are also left out of the optimizer, nothing that affects the loss ever changes. Train BERT alongside T5 by removing self.bert.eval(), calling self.bert.train(), and including BERT’s parameters in the optimizer.

  2. Incorrect Parameter Optimization: The optimizer might not include all the necessary parameters. Make sure the parameters of both the T5 and BERT models are passed to the optimizer (see the short sketch right after this list).

  3. Triplet Loss Implementation: The triplet loss might not be giving a useful training signal. Make sure each (anchor, positive, negative) triple is aligned row by row in the batch, and keep in mind that TripletMarginLoss is zero (and produces no gradient) for triplets that already satisfy the margin, so the positives and negatives need to be hard enough to matter.

  4. Code Structure and Initialization: The CustomModel class may have issues with how the model is initialized. Ensure that all components (BERT, T5, etc.) are properly initialized and moved to the correct device.

  5. Batch Size and Training Parameters: A small batch size (batch_size=8) and fewer epochs (10) might not be sufficient for the model to learn effectively. Consider increasing the batch size if possible and training for more epochs.

  6. Data Preprocessing: The way you preprocess the data (e.g., truncation) might affect the quality of the embeddings. Ensure that the truncation and padding are handled correctly.

  7. Device Handling: Ensure that all tensors are moved to the correct device (GPU if available) consistently.
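
For point 2, a minimal sketch (reusing the CustomModel class from the post; the learning rate is just a placeholder). Since CustomModel is a plain Python class rather than an nn.Module, there is no model.parameters(), so collect the parameters of the wrapped models explicitly:

import torch.optim as optim

# One optimizer over both sub-models, so the triplet loss can update whichever of them should learn
model = CustomModel()
params = list(model.T5.parameters()) + list(model.bert.parameters())
optimizer = optim.AdamW(params, lr=2e-5)  # placeholder learning rate

model.fit(optimizer, batch_size=8, epochs=10)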

Solution Code

Here is a corrected version of your CustomModel class and training setup:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration
import json
import pandas as pd
from tqdm import tqdm

class CustomModel:
    def __init__(self, truncate=True, path='/content/dataset_big_patent_v3.json',
                 bert_model_name="bert-base-uncased", margin=0.3, device=None):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert = AutoModel.from_pretrained(bert_model_name).to(self.device)

        self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.T5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(self.device)

        self.triplet_loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=2)

        self.load_data(path)
        if truncate:
            self.data["anchor"] = self.data["anchor"].apply(self.truncate_text)

        self.data['inputs'] = 'question: ' + self.data['query'] + '\ncontext: ' + self.data['anchor']
        self.dataset = TripletTextDataset(self.data)

    def load_data(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = pd.DataFrame(data)

    def truncate_text(self, text):
        text_lower = text.lower()
        point = next((p for p in (
            text_lower.find('technical field'),
            text_lower.find('invention'),
            text_lower.find('disclosure'),
            text_lower.find('this application')
        ) if p != -1), None)
        return text[point:] if point is not None else text

    def get_embedding(self, text, requires_grad=True):
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        if requires_grad:
            outputs = self.bert(**tokens)
        else:
            with torch.no_grad():
                outputs = self.bert(**tokens)
        return outputs.last_hidden_state[:, 0, :].squeeze(0)

    def fit(self, optimizer, batch_size=32, epochs=20):
        self.T5.train()
        self.bert.train()  # Enable training for BERT

        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            total_loss = 0.0
            print(f"\nEpoch {epoch + 1}/{epochs}")

            for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
                input_texts = batch['anchor']
                tokenized = self.t5_tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
                generated_ids = self.T5.generate(
                    **tokenized,
                    num_beams=4,
                    max_length=256,
                    early_stopping=True
                )
                generated_texts = [self.t5_tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]

                # Get anchor embeddings with grad
                anchor_emb = torch.stack([self.get_embedding(text, requires_grad=True) for text in generated_texts])

                # Get positive and negative embeddings (also with grad; requires_grad defaults to True here)
                positive_emb = torch.stack([self.get_embedding(p) for p in batch['positive']])
                negative_emb = torch.stack([self.get_embedding(n) for n in batch['negative']])

                optimizer.zero_grad()
                loss = self.triplet_loss_fn(anchor_emb, positive_emb, negative_emb)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            mean_loss = total_loss / len(dataloader)
            print(f"Mean triplet loss for epoch {epoch + 1}: {mean_loss:.4f}")

        return mean_loss

# Example usage:
# model = CustomModel()
# params = list(model.T5.parameters()) + list(model.bert.parameters())
# optimizer = optim.AdamW(params, lr=2e-5)
# model.fit(optimizer, batch_size=32, epochs=20)

Explanation

  1. Train BERT: Replaced self.bert.eval() with self.bert.train() so BERT runs in training mode and is treated as part of the model being optimized.
  2. Include All Parameters: The optimizer is built from the parameters of both the T5 and BERT models (see the example usage above), so everything the loss depends on can be updated.
  3. Increase Batch Size: Changed the default batch_size from 8 to 32 to provide better gradient estimates.
  4. Adjust Training Parameters: Increased the number of epochs to 20 for better learning.
  5. Use Correct Embeddings: The anchor, positive, and negative [CLS] embeddings are all computed with gradients enabled (requires_grad now defaults to True), so the triplet loss can propagate into BERT.

By making these adjustments, the model should learn more effectively during training.
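
As a quick sanity check (a small sketch, not part of the code above; it assumes a model and optimizer built as in the example usage), you can confirm after a training step that gradients actually reach the parameters you expect to update:

# Run right after loss.backward() (or after fit() returns, which leaves the last batch's gradients in place).
# If every norm prints as None or 0.0, the loss is not connected to those parameters.
for name, param in list(model.bert.named_parameters())[:3]:
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(f"bert {name}: grad norm = {grad_norm}")

for name, param in list(model.T5.named_parameters())[:3]:
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(f"t5   {name}: grad norm = {grad_norm}")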


Yep, thanks John, it’s learning fairly well!

But if I also train BERT, I feel like I’m drifting away from my initial goal, which is to generate a response that’s close to the positive example and far from the negative one.
By training BERT, it feels like I’m artificially lowering the loss, since the model is also learning how to represent those elements better.
I was actually thinking of BERT more like an external evaluator, rather than something that’s truly part of the model being optimized. Not sure though…
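
Roughly, by “external evaluator” I mean something like this sketch (reusing the names from the code above, with a placeholder learning rate): BERT’s weights are explicitly frozen and only T5’s parameters go to the optimizer:

import torch.optim as optim

# Treat BERT purely as a frozen evaluator: its weights never change
for param in model.bert.parameters():
    param.requires_grad_(False)
model.bert.eval()

# The optimizer then only holds T5's parameters
optimizer = optim.AdamW(model.T5.parameters(), lr=2e-5)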
