I have a PyTorch model (a pretrained Hugging Face DistilBERT) fine-tuned on my own transaction data to classify transactions into one of 5 classes. I currently extract the embeddings, append some other extracted features (each document has a time and date as well as a brief description), and finally run the result through a final classification layer.
However, there is a lot of valuable information to be gained by clustering transactions per customer based on their descriptions (rather than treating them as isolated transactions), such as time between transactions and other descriptive statistics. I'm trying to work out how to approach this, given that the model currently classifies record by record, and how it should fit into the model architecture.
One thought I have is to represent each cluster with its mean embedding plus descriptive statistics, and then train a classifier at that level instead. That is, I would:
- get transformer embeddings for each transaction
- cluster similar transaction descriptions together in a post-processing step
- derive descriptive statistics for each cluster
- classify at the cluster level
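For concreteness, here is a rough sketch of those steps (random vectors stand in for the DistilBERT embeddings, and the cluster count, stats, and labels are just illustrative dummies):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Stand-in for per-transaction DistilBERT embeddings (n_transactions x dim);
# in the real pipeline these come from the fine-tuned encoder.
n_tx, dim, n_clusters = 200, 32, 10
embeddings = rng.normal(size=(n_tx, dim))
# Stand-in extra feature: a timestamp (in seconds) per transaction
timestamps = np.sort(rng.integers(0, 1_000_000, size=n_tx))

# Step 2: cluster similar transactions (k chosen arbitrarily here)
cluster_ids = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(embeddings)

# Step 3: per-cluster features = mean embedding + descriptive statistics
cluster_features = []
for c in range(n_clusters):
    mask = cluster_ids == c
    mean_emb = embeddings[mask].mean(axis=0)
    ts = np.sort(timestamps[mask])
    gaps = np.diff(ts) if ts.size > 1 else np.array([0.0])
    stats = [mask.sum(), gaps.mean(), gaps.std()]  # count, mean/std time between transactions
    cluster_features.append(np.concatenate([mean_emb, stats]))
cluster_features = np.vstack(cluster_features)

# Step 4: classify at the cluster level (dummy labels covering the 5 classes)
cluster_labels = np.arange(n_clusters) % 5
clf = LogisticRegression(max_iter=1000).fit(cluster_features, cluster_labels)
print(cluster_features.shape)  # (10, 35): 32-dim mean embedding + 3 stats
```

The point is just that steps 2-4 operate on plain arrays after the encoder runs, which is why they currently feel like a separate post-processing model.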
My main question is: do I have to split this into two models (an embedding-extraction model, then a post-processing clustering step, then a classifier model), or could I somehow achieve this in a single model? Any thoughts/input are welcome. Thanks!
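To show roughly what I mean by "a single model", here is a hypothetical PyTorch sketch where the grouping (e.g. by customer) is known up front, the per-transaction representations are mean-pooled per group inside the forward pass, and the group-level stats are concatenated before the classification head, so everything trains end to end. A linear layer stands in for the DistilBERT encoder, and all names/shapes are illustrative:

```python
import torch
import torch.nn as nn

class GroupClassifier(nn.Module):
    def __init__(self, emb_dim=32, stat_dim=3, n_classes=5):
        super().__init__()
        # Stand-in for the DistilBERT encoder in the real model
        self.encoder = nn.Linear(emb_dim, emb_dim)
        self.head = nn.Linear(emb_dim + stat_dim, n_classes)

    def forward(self, tx_emb, group_ids, group_stats):
        # tx_emb: (n_tx, emb_dim), group_ids: (n_tx,) in [0, n_groups),
        # group_stats: (n_groups, stat_dim) precomputed descriptive statistics
        h = self.encoder(tx_emb)
        n_groups = group_stats.shape[0]
        # Differentiable mean-pooling of transaction representations per group
        sums = torch.zeros(n_groups, h.shape[1]).index_add_(0, group_ids, h)
        counts = torch.zeros(n_groups).index_add_(0, group_ids, torch.ones(len(group_ids)))
        pooled = sums / counts.clamp(min=1).unsqueeze(1)
        # Classify each group from pooled embedding + its stats
        return self.head(torch.cat([pooled, group_stats], dim=1))

model = GroupClassifier()
logits = model(torch.randn(200, 32), torch.randint(0, 10, (200,)), torch.randn(10, 3))
print(logits.shape)  # torch.Size([10, 5])
```

The catch is that this assumes a fixed grouping key rather than learned clusters; if the clustering itself has to be discovered from the embeddings, that step is not naturally differentiable, which is what pushes me toward the two-stage design.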