Hi there, I wasn't quite sure whether to post this in the intermediate or the beginner channel, but I guess I've made my choice.
I’m working on a similarity task where, for each example, I have an anchor, a query, a positive, and a negative.
The goal is to train a model that gives a higher similarity score to the positive response than to the negative one.
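To make the data format concrete, one record of my JSON file looks roughly like this (the field values here are just made-up placeholders):

example = {
    "anchor": "<the full patent text used as context>",
    "query": "<a question about the anchor>",
    "positive": "<a good response to the query>",
    "negative": "<a bad / unrelated response>",
}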
Here’s what I’ve implemented so far:
* I use a T5 model to generate a response to the query based on the anchor.
* I then encode the generated response using BERT (taking the [CLS] token embedding).
* Finally, I compute a triplet loss (see the sketch right after this list) between:
  * the generated (anchor) embedding
  * the positive embedding
  * the negative embedding
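For reference, this is what I understand TripletMarginLoss(margin=0.3, p=2) to compute, as a minimal standalone sketch (the tensors are random placeholders, just to show the shapes involved):

import torch
import torch.nn.functional as F

# toy stand-ins for the BERT [CLS] embeddings (batch of 4, hidden size 768)
a = torch.randn(4, 768)  # anchor = embedding of the T5-generated response
p = torch.randn(4, 768)  # positive embedding
n = torch.randn(4, 768)  # negative embedding

loss_fn = torch.nn.TripletMarginLoss(margin=0.3, p=2)
loss = loss_fn(a, p, n)

# same thing by hand: mean over the batch of max(d(a, p) - d(a, n) + margin, 0)
manual = torch.clamp(F.pairwise_distance(a, p) - F.pairwise_distance(a, n) + 0.3, min=0).mean()
assert torch.allclose(loss, manual)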
The problem is… the model doesn't seem to learn anything at all; the loss doesn't decrease.
If anyone can help me understand what I might be doing wrong, or point me in the right direction, I'd really appreciate it.
Here is the model class:
import json

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, T5Tokenizer, T5ForConditionalGeneration

# TripletTextDataset is my own Dataset wrapper (not shown here); it yields dicts
# with 'anchor', 'positive' and 'negative' text fields.


class CustomModel:
    def __init__(self, truncate=True, path='/content/dataset_big_patent_v3.json',
                 bert_model_name="bert-base-uncased", margin=0.3, device=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        # BERT encodes texts into [CLS] embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name).to(self.device)
        # T5 generates the response that should end up close to the positive
        self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.T5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(self.device)
        self.triplet_loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=2)
        self.load_data(path)
        if truncate:
            self.data["anchor"] = self.data["anchor"].apply(self.truncate_text)
        self.data['inputs'] = 'question: ' + self.data['query'] + '\ncontext: ' + self.data['anchor']
        self.dataset = TripletTextDataset(self.data)

    def load_data(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = pd.DataFrame(data)

    def truncate_text(self, text):
        # cut the anchor so it starts at one of a few known section markers, if present
        text_lower = text.lower()
        point = next((p for p in (
            text_lower.find('technical field'),
            text_lower.find('invention'),
            text_lower.find('disclosure'),
            text_lower.find('this application')
        ) if p != -1), None)
        return text[point:] if point is not None else text

    def get_embedding(self, text, requires_grad=False):
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        if requires_grad:
            outputs = self.bert(**tokens)
        else:
            with torch.no_grad():
                outputs = self.bert(**tokens)
        # [CLS] token embedding
        return outputs.last_hidden_state[:, 0, :].squeeze(0)

    def fit(self, optimizer, batch_size=8, epochs=10):
        self.T5.train()
        self.bert.eval()
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        for epoch in range(epochs):
            total_loss = 0.0
            print(f"\nEpoch {epoch + 1}/{epochs}")
            for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
                # Generate a response for each anchor with T5
                input_texts = batch['anchor']
                tokenized = self.t5_tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
                generated_ids = self.T5.generate(
                    **tokenized,
                    num_beams=4,
                    max_length=256,
                    early_stopping=True
                )
                generated_texts = [self.t5_tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]
                # Get anchor embeddings with grad
                anchor_emb = torch.stack([self.get_embedding(text, requires_grad=True) for text in generated_texts])
                # Get positive and negative embeddings without grad
                positive_emb = torch.stack([self.get_embedding(p) for p in batch['positive']])
                negative_emb = torch.stack([self.get_embedding(n) for n in batch['negative']])
                optimizer.zero_grad()
                loss = self.triplet_loss_fn(anchor_emb, positive_emb, negative_emb)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            mean_loss = total_loss / len(dataloader)
            print(f"Mean triplet loss for epoch {epoch + 1}: {mean_loss:.4f}")
        return mean_loss
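In case it matters, here is roughly how I instantiate and train it; the optimizer choice, the parameters passed to it, and the learning rate below are placeholders rather than exactly what I use:

model = CustomModel(path='/content/dataset_big_patent_v3.json')
# placeholder optimizer setup: which parameters to optimize and which lr are assumptions
optimizer = torch.optim.AdamW(model.T5.parameters(), lr=1e-4)
model.fit(optimizer, batch_size=8, epochs=10)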