Hi!
I am considering using Transformers as a “featurizer” in a scikit-learn NLP pipeline to generate document embeddings. This would let me take advantage of the convenient functionality sklearn offers (cross-validation, etc.) and easily swap pipeline components with those provided by sklearn.feature_extraction, which I would have a hard time recoding in pure PyTorch.
To do this, I created a custom estimator:
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import AutoModel, AutoTokenizer, pipeline


class BertEmbedder(BaseEstimator, TransformerMixin):
    def __init__(self):
        model = AutoModel.from_pretrained("bert-base-uncased")
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=512)
        # device=0 assumes a CUDA GPU is available
        self.embedder = pipeline("feature-extraction", model=model, tokenizer=tokenizer, device=0)

    def fit(self, X, y=None):
        # Nothing to fit: the pretrained model is used as-is
        return self

    def transform(self, raw_documents):
        # One entry per document, each of shape (1, n_tokens, hidden_size)
        embeddings = self.embedder(list(raw_documents), truncation=True, batch_size=4)
        document_embeddings = []
        for emb in embeddings:
            # Mean-pool over the token axis to get one vector per document
            x = np.array(emb).mean(axis=1).squeeze()
            document_embeddings.append(x)
        return np.array(document_embeddings)
and created a sklearn pipeline (see "6.1. Pipelines and composite estimators" in the scikit-learn 1.1.1 documentation), here plugging a linear SGDClassifier on top of the BERT featurizer for illustrative purposes:
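import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Single-point grid: np.logspace(-1, 0, 1) evaluates to [0.1]
parameters = {
    "classifier__alpha": np.logspace(-1, 0, 1),
}
pipe = Pipeline(
    [
        ("vectorizer", BertEmbedder()),
        (
            "classifier",
            SGDClassifier(
                loss="modified_huber",
                penalty="l2",
                alpha=5e-3,
                random_state=42,
                max_iter=100,
                tol=None,
            ),
        ),
    ],
)
gs = GridSearchCV(pipe, parameters, n_jobs=1, verbose=2, scoring="accuracy")
# train is assumed to be a pandas DataFrame with "text" and "class" columns
gs.fit(train["text"].values, train["class"].values)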
Is the above an OK thing to do?
Note that in the above, when I run CV, pipe recomputes the document embeddings over and over, which is something one would like to avoid, e.g. by caching the results of the featurization.
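To make the caching idea concrete, one option might be to memoize embeddings per document. Below is a minimal sketch using only the standard library. `EmbeddingCache`, `embed_batch`, and `fake_embed` are hypothetical names of mine, and `fake_embed` is only a cheap stand-in for the BERT call, not the real featurizer.

```python
import hashlib
from typing import Callable, Dict, List, Sequence


class EmbeddingCache:
    """Memoize per-document embeddings so repeated CV folds can reuse them.

    `embed_batch` is a hypothetical callable (e.g. a wrapper around the
    feature-extraction pipeline) mapping a list of texts to a list of vectors.
    """

    def __init__(self, embed_batch: Callable[[List[str]], List[List[float]]]):
        self.embed_batch = embed_batch
        self._store: Dict[str, List[float]] = {}

    @staticmethod
    def _key(text: str) -> str:
        # Hash the text so long documents do not become dictionary keys
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def __call__(self, texts: Sequence[str]) -> List[List[float]]:
        keys = [self._key(t) for t in texts]
        # Collect each unseen document once, preserving first-seen order
        missing: Dict[str, str] = {}
        for k, t in zip(keys, texts):
            if k not in self._store and k not in missing:
                missing[k] = t
        if missing:
            # Only the unseen documents reach the expensive embedding call
            vectors = self.embed_batch(list(missing.values()))
            for k, v in zip(missing, vectors):
                self._store[k] = v
        return [self._store[k] for k in keys]


# Stand-in for the expensive BERT call: counts how many texts it embeds
calls = {"n_texts": 0}


def fake_embed(texts: List[str]) -> List[List[float]]:
    calls["n_texts"] += len(texts)
    return [[float(len(t)), float(t.count(" "))] for t in texts]


cached = EmbeddingCache(fake_embed)
first = cached(["a b c", "hello world", "a b c"])
second = cached(["hello world", "new doc"])
print(calls["n_texts"])  # number of texts actually embedded
```

If I went this way, I assume the cache would have to live outside the estimator (module level or on disk), since GridSearchCV clones the pipeline for every fit and a clone is rebuilt from the constructor parameters only. Computing all embeddings once up front and grid-searching only the classifier on the resulting array would be an even simpler variant of the same idea.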
Do you have suggestions for a better implementation? I am not sure whether I could take advantage of libraries like skorch (skorch 0.11.0 documentation) or PyTorch Lightning to try out different classification heads, or whether mixing different frameworks is an antipattern to avoid.