Extracting and adding document clustering features to a document classification model

I currently have a PyTorch model (a pretrained Hugging Face DistilBERT) fine-tuned on my own transaction data to classify transactions into one of 5 classes. I extract the embeddings, append some other extracted features (each transaction has a time and date as well as a brief description), and then run the result through a final classification layer.
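For reference, my current setup looks roughly like the sketch below (the class name, the extra-feature dimension, and the use of the [CLS]-position embedding are placeholders for illustration, not my exact code):

```python
import torch
import torch.nn as nn
from transformers import DistilBertModel

class TransactionClassifier(nn.Module):
    """DistilBERT embedding + extra hand-crafted features -> linear classifier."""

    def __init__(self, n_extra_features: int = 4, n_classes: int = 5):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden = self.bert.config.hidden_size  # 768 for DistilBERT
        self.classifier = nn.Linear(hidden + n_extra_features, n_classes)

    def forward(self, input_ids, attention_mask, extra_features):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = out.last_hidden_state[:, 0]              # embedding at the [CLS] position
        combined = torch.cat([cls_emb, extra_features], dim=-1)
        return self.classifier(combined)                   # per-transaction logits
```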

However, there is a lot of valuable information to be gained if I cluster transactions per customer based on their descriptions (rather than treating them as isolated records), such as the time between transactions and other descriptive statistics. I am trying to work out how to approach this, given that the model currently classifies record by record, and how it should fit into the model architecture.
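As a concrete example of the kind of group-level features I mean, something like the following (toy data and hypothetical column names such as `customer_id` and `timestamp`):

```python
import pandas as pd

# Toy data standing in for real transactions (hypothetical column names)
df = pd.DataFrame({
    "customer_id": ["A", "A", "A", "B", "B"],
    "timestamp": pd.to_datetime([
        "2024-01-01 09:00", "2024-01-03 10:30", "2024-01-10 08:00",
        "2024-01-02 12:00", "2024-01-05 12:00",
    ]),
    "description": ["coffee", "coffee shop", "cafe", "rent", "rent payment"],
})

# Time between consecutive transactions per customer
df = df.sort_values(["customer_id", "timestamp"])
df["gap_hours"] = (
    df.groupby("customer_id")["timestamp"].diff().dt.total_seconds() / 3600
)

# Per-customer descriptive statistics that could be appended as features
stats = df.groupby("customer_id")["gap_hours"].agg(["mean", "std", "min", "max"])
print(stats)
```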

One thought I have is to represent each cluster with its mean embedding plus descriptive statistics, and then train a classifier at that level instead. That is, I would (sketched in code after the list):

  1. Get transformer embeddings for each transaction
  2. Cluster similar transaction descriptions together in a post-processing step
  3. Derive descriptive statistics for each cluster
  4. Classify at the cluster level
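Concretely, the pipeline I have in mind looks roughly like this. Random arrays stand in for the real embeddings, timestamps, and cluster targets, and KMeans is just one possible clustering choice (HDBSCAN or agglomerative clustering would slot in the same way):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Stand-ins for step 1: per-transaction DistilBERT embeddings and timestamps
embeddings = rng.normal(size=(500, 768))             # would come from the fine-tuned model
timestamps = np.sort(rng.uniform(0, 365, size=500))  # days, for illustration only

# Step 2: cluster similar descriptions
n_clusters = 20
labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(embeddings)

# Step 3: mean embedding + descriptive statistics per cluster
rows = []
for c in range(n_clusters):
    member_emb = embeddings[labels == c]
    member_ts = np.sort(timestamps[labels == c])
    gaps = np.diff(member_ts) if len(member_ts) > 1 else np.array([0.0])
    cluster_stats = np.array([gaps.mean(), gaps.std(), len(member_ts)])
    rows.append(np.concatenate([member_emb.mean(axis=0), cluster_stats]))
X_clusters = np.vstack(rows)

# Step 4: classify at the cluster level (random targets as placeholders here;
# in practice they might come from a majority vote over transaction labels)
y_clusters = rng.integers(0, 5, size=n_clusters)
clf = LogisticRegression(max_iter=1000).fit(X_clusters, y_clusters)
```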

My main question is: do I have to split this into separate pieces (an embedding-extraction model, a post-processing clustering step, then a classifier model), or could I somehow achieve this in a single model? Any other thoughts/input are welcome. Thanks!