If learning is merely inefficient, the hyperparameters may be at fault, but if there is no change at all, the problem is more likely in the model code itself.
Answer by HuggingChat:
Your model may not be learning during training due to several potential issues with the model setup and training process. Here are the possible reasons and solutions based on the provided code and reasoning:
- BERT Model Is in Eval Mode: In your `fit` method you call `self.bert.eval()`, and BERT is treated as a frozen feature extractor. BERT is the component that produces the embeddings, so if it never receives gradient updates it cannot learn anything from the triplet loss. Train BERT alongside T5: remove `self.bert.eval()`, switch to `self.bert.train()`, and include BERT's parameters in the optimizer.
- Incorrect Parameter Optimization: The optimizer might not include all the necessary parameters. Ensure that both the T5 and the BERT parameters are passed to the optimizer (a quick way to check which parameters actually receive gradients is sketched after this list).
- Triplet Loss Implementation: The triplet loss might not be training the model effectively. The way you build the triplets (anchors, positives, negatives) could be flawed. Make sure the triplets are properly formed so the model can learn discriminative features.
- Code Structure and Initialization: The `CustomModel` class may have issues with how the model is initialized. Ensure that all components (BERT, T5, etc.) are properly initialized and moved to the correct device.
- Batch Size and Training Parameters: A small batch size (`batch_size=8`) and few epochs (10) might not be enough for the model to learn effectively. Consider increasing the batch size if memory allows and training for more epochs.
- Data Preprocessing: The way you preprocess the data (e.g., truncation) can affect the quality of the embeddings. Ensure that truncation and padding are handled correctly.
- Device Handling: Ensure that all tensors are consistently moved to the correct device (GPU if available).
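Before changing anything, it can help to confirm which component is actually frozen. One quick check is to inspect each parameter's `requires_grad` flag and gradient norm right after `loss.backward()`. This is a minimal sketch; the `check_learning` helper is not part of the original code, and `model.bert` / `model.T5` refer to the class shown further down:

```python
import torch

def check_learning(module: torch.nn.Module, name: str = "model"):
    """Print, for each parameter, whether it requires grad and its gradient norm.

    Call right after loss.backward(): a frozen or disconnected module shows
    requires_grad=False or grad_norm=None for all of its parameters.
    """
    for pname, p in module.named_parameters():
        grad_norm = p.grad.norm().item() if p.grad is not None else None
        print(f"{name}.{pname}: requires_grad={p.requires_grad}, grad_norm={grad_norm}")

# Example (after loss.backward(), before optimizer.step()):
# check_learning(model.bert, "bert")
# check_learning(model.T5, "t5")
```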
Solution Code
Here is a corrected version of your `CustomModel` class and training setup:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration
from sklearn.metrics import accuracy_score
import json
import pandas as pd
from tqdm import tqdm


class CustomModel:
    def __init__(self, truncate=True, path='/content/dataset_big_patent_v3.json',
                 bert_model_name="bert-base-uncased", margin=0.3, device=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert = AutoModel.from_pretrained(bert_model_name).to(self.device)
        self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.T5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(self.device)
        self.triplet_loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=2)
        self.load_data(path)
        if truncate:
            self.data["anchor"] = self.data["anchor"].apply(self.truncate_text)
        self.data['inputs'] = 'question: ' + self.data['query'] + '\ncontext: ' + self.data['anchor']
        self.dataset = TripletTextDataset(self.data)

    def load_data(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = pd.DataFrame(data)

    def truncate_text(self, text):
        # Drop the boilerplate before the first relevant section marker, if one exists
        text_lower = text.lower()
        point = next((p for p in (
            text_lower.find('technical field'),
            text_lower.find('invention'),
            text_lower.find('disclosure'),
            text_lower.find('this application')
        ) if p != -1), None)
        return text[point:] if point is not None else text

    def get_embedding(self, text, requires_grad=True):
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        if requires_grad:
            outputs = self.bert(**tokens)
        else:
            with torch.no_grad():
                outputs = self.bert(**tokens)
        # Use the [CLS] token embedding as the sentence representation
        return outputs.last_hidden_state[:, 0, :].squeeze(0)

    def fit(self, optimizer, batch_size=32, epochs=20):
        self.T5.train()
        self.bert.train()  # Enable training mode for BERT
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        for epoch in range(epochs):
            total_loss = 0.0
            print(f"\nEpoch {epoch + 1}/{epochs}")
            for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
                input_texts = batch['anchor']
                tokenized = self.t5_tokenizer(input_texts, return_tensors="pt",
                                              padding=True, truncation=True).to(self.device)
                # NOTE: generate() returns discrete token ids, so no gradient flows back
                # into T5 from this step; the triplet loss below updates BERT.
                generated_ids = self.T5.generate(
                    **tokenized,
                    num_beams=4,
                    max_length=256,
                    early_stopping=True
                )
                generated_texts = [self.t5_tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]
                # Anchor embeddings computed with gradients enabled
                anchor_emb = torch.stack([self.get_embedding(text, requires_grad=True) for text in generated_texts])
                # Positive and negative embeddings (gradients also flow through BERT here)
                positive_emb = torch.stack([self.get_embedding(p) for p in batch['positive']])
                negative_emb = torch.stack([self.get_embedding(n) for n in batch['negative']])
                optimizer.zero_grad()
                loss = self.triplet_loss_fn(anchor_emb, positive_emb, negative_emb)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            mean_loss = total_loss / len(dataloader)
            print(f"Mean triplet loss for epoch {epoch + 1}: {mean_loss:.4f}")
        return mean_loss


# Example usage:
# model = CustomModel()
# optimizer = optim.AdamW(list(model.bert.parameters()) + list(model.T5.parameters()), lr=2e-5)
# model.fit(optimizer, batch_size=32, epochs=20)
```
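The code above references a `TripletTextDataset` that is not shown anywhere. A minimal sketch of what it could look like, assuming the DataFrame holds raw strings in `anchor`, `positive`, and `negative` columns (adjust the keys to your actual schema):

```python
from torch.utils.data import Dataset

class TripletTextDataset(Dataset):
    """Wraps a pandas DataFrame so the DataLoader yields dicts of raw strings."""

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return {
            "anchor": row["anchor"],
            "positive": row["positive"],
            "negative": row["negative"],
        }
```

With this wrapper, the default `DataLoader` collate function turns each batch into a dict of string lists, which is exactly what the `fit` loop indexes with `batch['anchor']`, `batch['positive']`, and `batch['negative']`.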
Explanation
- Train BERT: Removed `self.bert.eval()` and set `self.bert.train()`, so BERT runs in training mode and its weights can be updated by the triplet loss.
- Include All Parameters: The optimizer now receives parameters from both the T5 and BERT models (see the parameter-group sketch after this explanation for one way to set this up).
- Increase Batch Size: Changed the default `batch_size` from 8 to 32 to provide better gradient estimates.
- Adjust Training Parameters: Increased the number of epochs to 20 to give the model more time to learn.
- Use Correct Embeddings: Ensure that the embeddings are correctly computed and passed to the triplet loss function.
By making these adjustments, the model should learn more effectively during training.
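On the "include all parameters" point, a common pattern is to give the two models separate parameter groups so each can use its own learning rate. A minimal sketch, assuming `model` is a `CustomModel` instance as defined above; the learning rates are illustrative, not tuned values:

```python
import torch.optim as optim

model = CustomModel()  # uses the defaults from the solution code

# Separate parameter groups: BERT (embedding encoder) and T5 (generator)
optimizer = optim.AdamW([
    {"params": model.bert.parameters(), "lr": 2e-5},  # pretrained encoders usually prefer small LRs
    {"params": model.T5.parameters(), "lr": 1e-4},    # illustrative value
])

model.fit(optimizer, batch_size=32, epochs=20)
```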