Hi,
Thank you for your answer, and sorry for the late reply (got distracted by work, life, etc.).
I have read/watched some of the resources you sent (this video in particular is really nice: https://www.youtube.com/watch?v=uk6SlTzfbUY) and I now have a basic grasp of how positive-unlabelled (PU) learning works.
I have implemented two approaches with the following algorithms:
- OneClassSVM
- WeightedElkanotoPuClassifier
Since last time, I have built a very modest dataset of “bad” articles: articles I don’t want to read because I don’t find them interesting. I have labelled 70 of them and intend to use them in my validation set.
OneClassSVM
My approach is:
- load 7465 “good” articles (the ones I have read and find interesting)
- compute embeddings with all-MiniLM-L12-v2 for good articles
- train classifier on good embeddings
- prepare 100 good articles and 70 bad articles (none of them was used during training)
- compute precision on validation set:
(# of correct good + # of correct bad) / (total good + total bad)
During validation:
- if an article is in fact good and the model gives a score > 0.5 → +1 (correct)
- if an article is in fact good and the model gives a score ≤ 0.5 → 0 (incorrect)
- for bad articles it is the other way around: a score ≤ 0.5 counts as correct (see the small sketch below)
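To make the counting concrete, here is a minimal sketch of the check with made-up placeholder scores (not my real validation data):

import numpy as np

# Placeholder scores, for illustration only.
good_scores = np.array([0.9, 0.7, 0.4])  # articles that are in fact good
bad_scores = np.array([0.2, 0.6])        # articles that are in fact bad

correct_good = int(np.sum(good_scores > 0.5))  # good articles above the threshold
correct_bad = int(np.sum(bad_scores <= 0.5))   # bad articles at or below the threshold

overall = (correct_good + correct_bad) / (len(good_scores) + len(bad_scores))
print(f"Overall: {overall:.2f}")  # 0.60 for this toy example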
WeightedElkanotoPuClassifier
My approach is:
- load 7465 “good” articles (the ones I have read and find interesting)
- load 7000 unlabelled articles (they could be good or bad)
- compute embeddings with all-MiniLM-L12-v2 for good and unlabelled articles
- train classifier on good and unlabelled embeddings
- prepare 100 good articles and 70 bad articles (none of them was used during training)
- compute precision on validation set:
(# of correct good + # of correct bad) / (total good + total bad)
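The core difference from the OneClassSVM setup is that the unlabelled articles take part in training with label 0, where 0 means “unlabelled” rather than “bad”. Stripped down, the training part of the full script below looks roughly like this (random placeholder embeddings stand in for the real MiniLM ones; 384 is the all-MiniLM-L12-v2 embedding size):

import numpy as np
from pulearn import WeightedElkanotoPuClassifier
from sklearn.svm import SVC

# Placeholder embeddings; in the real script these come from all-MiniLM-L12-v2.
pos_embeddings = np.random.rand(50, 384)        # labelled "good" articles
unlabeled_embeddings = np.random.rand(50, 384)  # unlabelled articles (good or bad)

X = np.concatenate([pos_embeddings, unlabeled_embeddings], axis=0)
y = np.concatenate([np.ones(len(pos_embeddings)), np.zeros(len(unlabeled_embeddings))])

pu_estimator = WeightedElkanotoPuClassifier(
    estimator=SVC(probability=True),
    labeled=len(pos_embeddings),
    unlabeled=len(unlabeled_embeddings),
    hold_out_ratio=0.2,
)
pu_estimator.fit(X, y)
scores = pu_estimator.predict_proba(X)[:, 1]  # probability of being a "good" article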
Results
I got insane results and they feel too good to be true:
- OneClassSVM: 92%
- WeightedElkanotoPuClassifier: 98%
Questions
- Does it look sensible to you?
- Would you have any tips?
- Do I measure the precision correctly? Should I use another metric?
NOTE: I have done a bit of parameter tuning on the OneClassSVM but not on the WeightedElkanotoPuClassifier.
Code
OneClassSVM
import asyncio
import numpy as np
from bs4 import BeautifulSoup
from cleantext import clean
from sentence_transformers import SentenceTransformer
# from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import OneClassSVM
from feedoscope.data_registry import data_registry as dr
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
def strip_html_keep_text(html: str) -> str:
soup = BeautifulSoup(html, "html.parser")
text = soup.get_text(separator=" ", strip=True)
return " ".join(text.split())
def compute_embeddings(model, texts: list[str]):
embeddings = model.encode(
texts, show_progress_bar=True, normalize_embeddings=True, convert_to_numpy=True
)
return embeddings
def prepare_articles_text(articles) -> list[str]:
texts = []
for a in articles:
text = clean(
strip_html_keep_text(f"{a['feed_name']} {a['title']} {a['content']}")
)
texts.append(text)
return texts
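# Rescale raw OneClassSVM decision_function outputs to [0, 1] so the 0.5 threshold can be applied.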
def normalize_scores(scores):
scaler = MinMaxScaler()
return scaler.fit_transform(scores.reshape(-1, 1)).flatten()
def ocsvm_score(estimator, X):
# Higher decision_function means more inlier-like
return np.mean(estimator.decision_function(X))
async def main() -> None:
print("Loading SentenceTransformer model...")
model = SentenceTransformer(MODEL_NAME)
print("Model loaded successfully.")
print("Collecting articles from the database...")
await dr.global_pool.open(wait=True)
articles = await dr.get_articles()
print(f"Collected {len(articles)} articles.")
print("Computing embeddings for articles...")
embeddings = compute_embeddings(model, prepare_articles_text(articles))
print(f"Computed embeddings for {len(embeddings)} articles.")
# Use best parameters directly
ocsvm = OneClassSVM(kernel="linear", gamma="scale", nu=0.2)
ocsvm.fit(embeddings)
# # Hyperparameter tuning for OneClassSVM
# param_grid = {
# "kernel": ["rbf", "linear", "sigmoid"],
# "gamma": ["scale", "auto", 0.01, 0.1, 1],
# "nu": [0.01, 0.05, 0.1, 0.2]
# }
# print("Tuning OneClassSVM hyperparameters...")
# ocsvm = OneClassSVM()
# grid = GridSearchCV(
# OneClassSVM(),
# param_grid,
# cv=3,
# n_jobs=-1,
# scoring=ocsvm_score
# )
# grid.fit(embeddings)
# best_ocsvm = grid.best_estimator_
# print("Best parameters:", grid.best_params_)
not_good_sample = await dr.get_sample_not_good()
not_good_embeddings = compute_embeddings(
model, prepare_articles_text(not_good_sample)
)
raw_scores = ocsvm.decision_function(not_good_embeddings)
scores = normalize_scores(raw_scores)
    correct_not_good, total_not_good = sum(s <= 0.5 for s in scores), len(scores)
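    # Validation on the held-out "good" articles: a normalized score > 0.5 counts as correct.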
good_sample = await dr.get_sample_good()
good_embeddings = compute_embeddings(model, prepare_articles_text(good_sample))
raw_scores = ocsvm.decision_function(good_embeddings)
scores = normalize_scores(raw_scores)
    correct_good, total_good = sum(s > 0.5 for s in scores), len(scores)
print(
f"Overall precision: {(correct_good + correct_not_good) / (total_good + total_not_good):.2f}"
)
if __name__ == "__main__":
asyncio.run(main())
WeightedElkanotoPuClassifier
import asyncio
import numpy as np
from bs4 import BeautifulSoup
from cleantext import clean
from pulearn import WeightedElkanotoPuClassifier
from sentence_transformers import SentenceTransformer
from sklearn.svm import SVC
from feedoscope.data_registry import data_registry as dr
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
def strip_html_keep_text(html: str) -> str:
soup = BeautifulSoup(html, "html.parser")
text = soup.get_text(separator=" ", strip=True)
return " ".join(text.split())
def compute_embeddings(model, texts: list[str]):
embeddings = model.encode(
texts, show_progress_bar=True, normalize_embeddings=True, convert_to_numpy=True
)
return embeddings
def prepare_articles_text(articles) -> list[str]:
texts = []
for a in articles:
text = clean(
strip_html_keep_text(f"{a['feed_name']} {a['title']} {a['content']}")
)
texts.append(text)
return texts
async def main() -> None:
print("Loading SentenceTransformer model...")
model = SentenceTransformer(MODEL_NAME)
print("Model loaded successfully.")
print("Collecting articles from the database...")
await dr.global_pool.open(wait=True)
articles = await dr.get_articles()
print(f"Collected {len(articles)} articles.")
print("Computing embeddings for articles...")
embeddings = compute_embeddings(model, prepare_articles_text(articles))
print(f"Computed embeddings for {len(embeddings)} articles.")
print("Collecting unread articles from the database...")
await dr.global_pool.open(wait=True)
unlabeled_articles = await dr.get_unread_articles()
print(f"Collected {len(unlabeled_articles)} unread articles.")
print("Computing embeddings for unread articles...")
unlabeled_embeddings = compute_embeddings(
model, prepare_articles_text(unlabeled_articles)
)
print(f"Computed embeddings for {len(unlabeled_embeddings)} unread articles.")
# Combine embeddings and labels for PU learning
X = np.concatenate([embeddings, unlabeled_embeddings], axis=0)
y = np.concatenate(
[np.ones(len(embeddings)), np.zeros(len(unlabeled_embeddings))], axis=0
)
print("Fitting PU classifier...")
# Takes a while for 7k + 7k articles
svc = SVC(C=10, kernel="rbf", gamma=0.4, probability=True)
# svc = SVC(C=10, kernel='linear', gamma='scale', probability=True)
pu_estimator = WeightedElkanotoPuClassifier(
estimator=svc,
labeled=len(embeddings),
unlabeled=len(unlabeled_embeddings),
hold_out_ratio=0.2,
)
pu_estimator.fit(X, y)
print("PU classifier fitted successfully.")
not_good_sample = await dr.get_sample_not_good()
not_good_embeddings = compute_embeddings(
model, prepare_articles_text(not_good_sample)
)
scores = pu_estimator.predict_proba(not_good_embeddings)[:, 1]
    correct_not_good, total_not_good = sum(s <= 0.5 for s in scores), len(scores)
good_sample = await dr.get_sample_good()
good_embeddings = compute_embeddings(model, prepare_articles_text(good_sample))
scores = pu_estimator.predict_proba(good_embeddings)[:, 1]
    correct_good, total_good = sum(s > 0.5 for s in scores), len(scores)
print(
f"Overall precision: {(correct_good + correct_not_good) / (total_good + total_not_good):.2f}"
)
if __name__ == "__main__":
asyncio.run(main())