Distilbert-base-nli-stsb-mean-tokens OOM encoding sentences of 100K docs

Hi,

I am using sentence-transformers/distilbert-base-nli-stsb-mean-tokens to embed sentences from a corpus of 100K academic articles. The model is defined as below:

from sentence_transformers import SentenceTransformer, models

self.model_name = 'sentence-transformers/distilbert-base-nli-stsb-mean-tokens'

# models.BERT is the API from older sentence-transformers versions;
# newer versions call this models.Transformer.
self.word_embedding_model = models.BERT(
    self.model_name,
    max_seq_length=128,
    do_lower_case=True)

self.pooling_model = models.Pooling(
    self.word_embedding_model.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False)

self.model = SentenceTransformer(modules=[self.word_embedding_model, self.pooling_model])

self.corpus_embeddings = self.model.encode(self.corpus)

Running with 64 GB RAM and a 3090 FE (24 GB VRAM), the encoding task makes it about 50% of the way through before running out of memory.

I would be most grateful for any guidance on how I might handle encoding the entire corpus, e.g. chunking it up or reducing the model size (and the best approach to that).

Many thanks

Do you want one vector for the whole corpus, one per sentence, or something else? What is inside that corpus variable?

Currently I have one vector per sentence (self.corpus_embeddings). I then use the same model to embed a query phrase, and use cosine similarity to rank the corpus embeddings against the query:

import scipy.spatial

co_dist = scipy.spatial.distance.cdist(query_embeddings, self.corpus_embeddings, "cosine")[0]
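To turn those distances into a ranking, a small follow-up sketch (reusing the variable names from the snippets above):

import numpy as np

# Ascending cosine distance = descending similarity.
ranked = np.argsort(co_dist)
for idx in ranked[:5]:
    print(self.corpus[idx], 1 - co_dist[idx])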

First things first: AFAIK you do not need all that pooling hassle. The sentence-transformers model sentence-transformers/distilbert-base-nli-stsb-mean-tokens already pools over the tokens, so you get one output vector per input sentence. So:

self.model = SentenceTransformer("sentence-transformers/distilbert-base-nli-stsb-mean-tokens")
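A quick illustrative check (the example sentences are made up; 768 is DistilBERT's hidden size):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/distilbert-base-nli-stsb-mean-tokens")
embeddings = model.encode(["first sentence", "second sentence"])
print(embeddings.shape)  # (2, 768): one 768-dim vector per sentence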

Second, I am still not sure what is inside self.corpus, but I guess it contains all the sentences (List[str]). Encoding all sentences of 100K articles is not a memory-lenient task.

Third, for what you want to do it is probably best to use something like FAISS for querying vectors. You're probably better off incrementally building a FAISS index: run a batch loop outside the sentence-transformers encode call and add each resulting batch of embeddings to the index, as sketched below. Perhaps you can even use the FAISS capabilities that are embedded in the datasets library of HF. Example here.
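Something along these lines (a minimal sketch, not your exact code: the batch size, index type, and variable names are illustrative assumptions):

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/distilbert-base-nli-stsb-mean-tokens")
dim = model.get_sentence_embedding_dimension()

# With L2-normalized vectors, inner product equals cosine similarity.
index = faiss.IndexFlatIP(dim)

batch_size = 1000  # illustrative; tune to your RAM/VRAM
for start in range(0, len(corpus), batch_size):
    batch = corpus[start:start + batch_size]
    embeddings = np.asarray(model.encode(batch), dtype="float32")
    faiss.normalize_L2(embeddings)  # in-place normalization
    index.add(embeddings)           # only the index holds the vectors

# Query: top-5 most similar corpus sentences.
query = model.encode(["some query phrase"]).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, k=5)

This way only one batch of embeddings lives in memory at a time. The datasets route is similar: store the embeddings in a dataset column and use its add_faiss_index and get_nearest_examples methods.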

For future reference: we are glad to help here when you use the HF repositories, but if you use third-party libraries (such as sentence-transformers), then you should ask questions on their channels, not here.


Thanks very much for the tips and guidance @BramVanroy. After a scan of the datasets library, that looks very useful indeed. Thank you again!