I'm trying to create a FAISS index for the MS-MARCO dataset, and I'm following the documentation provided here.
I'm trying to understand whether there is a more batch-efficient way to build the Faiss index. The current worked example seems to take each passage and encode it one by one, and I'm not sure if this is the only way to do it, or if datasets has some functionality that can make this go faster.
The reason I'm asking is that the estimated time to index "just" the training data is around 530 hours on a GPU Colab notebook.
Any insight on this would be appreciated.
This is the code snippet that I've been working with:
Thanks @BramVanroy for helping out with this. I figured out the batching part once I went through the documentation a bit more carefully. I do, however, have a follow-up:
I moved the model and inputs to the GPU hoping that would be faster, but it seems that doesn't really work with multiprocessing?
Could you tell me if there's something really obvious I'm missing here?
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch

torch.set_grad_enabled(False)  # inference only, no gradients needed

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_encoder = ctx_encoder.to('cuda:0')

from datasets import load_from_disk

# a version of the MARCO dataset that only has passage text
ds = load_from_disk("drive/MyDrive/marco")
print(ctx_encoder.device)

def encode(example):
    # tokenize the whole batch of passages and move the tensors to the GPU
    inputs = ctx_tokenizer(example['passages'], return_tensors="pt",
                           padding="longest", truncation="longest_first").to('cuda:0')
    # pooled output -> one embedding per passage, back on the CPU as numpy
    npys = ctx_encoder(**inputs)[0].cpu().numpy()
    return {'embeddings': npys}

ds_with_embeddings = ds.map(encode, batched=True, batch_size=100, num_proc=6)
This fails with the following error on a GPU Colab instance:
RuntimeError: CUDA error: initialization error
It is likely that the multiprocessing step does not work well with GPU-accelerated tasks, since each worker duplicates the whole main process, so you will also run out of memory. (I'm not sure why it runs into an initialization error, though.) Does it work without num_proc?
It's possible that you should use tokenizer.batch_encode_plus here, since you are passing a batch.