FAISS indexing for MARCO dataset

Hey everyone,

I’m trying to create a FAISS index for the MS-MARCO dataset, following the documentation provided here.

I’m trying to understand whether there is a way to build the FAISS index in a more efficient, batched way. The worked example seems to take each example and encode it one at a time, and I’m not sure whether that is the only way to do it or whether datasets has some functionality that can make this go faster.

The reason I’m asking is that the estimated time shown for indexing “just” the training data is around 530 hours on a GPU Colab notebook.

Any insight on this would be appreciated.

This is the code snippet that I’ve been working with:

!pip install transformers datasets faiss-gpu

import torch
from datasets import load_dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

torch.set_grad_enabled(False)  # inference only: no autograd bookkeeping

# from_pretrained loads the model on the CPU; nothing below moves it to the GPU
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

ds = load_dataset('ms_marco', 'v2.1', split='train')

# One forward pass per row: the row's passages are encoded together, [0] selects
# pooler_output, and the second [0] keeps only the first passage's embedding
ds_with_embeddings = ds.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example['passages']['passage_text'], return_tensors="pt", padding="longest"))[0][0].numpy()})
ds_with_embeddings.add_faiss_index(column='embeddings')

ds_with_embeddings.save_faiss_index('embeddings', 'drive/MyDrive/marco.faiss.train')

There are two things you can do, AFAIK:

  • use multiprocessing in map by setting num_proc to a value greater than 1
  • use batching (the lack of it is probably the biggest bottleneck) by setting batched=True and batch_size to a reasonable value.
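To make the batched=True suggestion concrete: with batching, the mapped function receives a dict of lists (one list per column, up to batch_size rows) and must return a dict of lists with one value per input row, so the encoder runs once per batch instead of once per row. The sketch below emulates that contract in plain Python; `batched_map` and `fake_encode` are hypothetical stand-ins (not the datasets API), and a counter shows how many "forward passes" each approach needs.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]

def batched_map(rows: List[dict], fn: Callable[[Batch], Batch], batch_size: int) -> List[dict]:
    """Hypothetical emulation of map(batched=True): fn gets a dict of lists per batch."""
    out: List[dict] = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Column-oriented view of the chunk: {"text": [...], ...}
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        result = fn(batch)
        # Every returned column must have exactly one value per input row
        assert all(len(values) == len(chunk) for values in result.values())
        for i, row in enumerate(chunk):
            out.append({**row, **{key: values[i] for key, values in result.items()}})
    return out

calls = {"count": 0}

def fake_encode(batch: Batch) -> Batch:
    """Hypothetical stand-in for one encoder forward pass over a whole batch of texts."""
    calls["count"] += 1
    return {"embeddings": [[float(len(text))] for text in batch["text"]]}

rows = [{"text": "passage " * (i % 5 + 1)} for i in range(1000)]

calls["count"] = 0
one_by_one = batched_map(rows, fake_encode, batch_size=1)
passes_unbatched = calls["count"]

calls["count"] = 0
batched = batched_map(rows, fake_encode, batch_size=64)
passes_batched = calls["count"]

print(passes_unbatched, passes_batched)  # 1000 16
```

The per-row results are identical; only the number of calls changes. On a GPU, one forward pass over 64 padded passages typically costs far less than 64 separate passes, which is where batching saves time.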

E.g. something like this (untested; you may need to change some things here and there):

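ds_with_embeddings = ds.map(lambda batch: {'embeddings': ctx_encoder(**ctx_tokenizer(
                                # batch['passages'] is a list of per-row dicts; keep each row's first passage
                                [p['passage_text'][0] for p in batch['passages']],
                                return_tensors="pt",  # the encoder is a PyTorch model, so not "np"
                                padding="longest",
                                truncation=True))[0].numpy()},  # [0] = pooler_output, one vector per row
                            batched=True,
                            batch_size=64,
                            num_proc=6)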
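One detail the batched example depends on: with batched=True, a nested column like MS MARCO’s `passages` arrives as a list with one dict per row (each dict holding lists such as `passage_text`), so `batch['passages']['passage_text']` does not index into it. The helpers below are hypothetical and operate on a hand-built mock batch shaped that way; the layout is an assumption worth checking by printing a real batch. One helper keeps the first passage per row, as the snippets above do, and the other flattens every passage while remembering which row it came from.

```python
from typing import Dict, List, Tuple

def first_passages(batch: Dict[str, list]) -> List[str]:
    """Hypothetical helper: first passage of every row (empty string if a row has none)."""
    return [p["passage_text"][0] if p["passage_text"] else "" for p in batch["passages"]]

def all_passages(batch: Dict[str, list]) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: every passage in the batch plus the index of its row."""
    texts: List[str] = []
    row_ids: List[int] = []
    for row_idx, p in enumerate(batch["passages"]):
        for text in p["passage_text"]:
            texts.append(text)
            row_ids.append(row_idx)
    return texts, row_ids

# Mock batch with two rows, shaped like an (assumed) batched `passages` column
mock_batch = {
    "query": ["what is faiss", "what is dpr"],
    "passages": [
        {"is_selected": [1, 0], "passage_text": ["FAISS is a library.", "It searches vectors."], "url": ["a", "b"]},
        {"is_selected": [0], "passage_text": ["DPR is a dense retriever."], "url": ["c"]},
    ],
}

print(first_passages(mock_batch))  # ['FAISS is a library.', 'DPR is a dense retriever.']
texts, row_ids = all_passages(mock_batch)
print(row_ids)  # [0, 0, 1]
```

Embedding every passage yields more vectors than input rows; a batched map can only return a different number of rows when the original columns are dropped (remove_columns), so flattening the passages into their own dataset first may be simpler.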


Thanks @BramVanroy for helping out with this. I figured out the batching part once I went through the documentation more carefully. I do, however, have a follow-up:

I moved the model and the inputs to the GPU, hoping that would be faster, but it seems that doesn’t work with multiprocessing?

Could you tell me if there’s something really obvious I’m missing here?

import torch
from datasets import load_from_disk
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

torch.set_grad_enabled(False)

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# This initializes CUDA in the main process, before map() starts its worker processes
ctx_encoder = ctx_encoder.to('cuda:0')

# a version of the MARCO dataset that only has passage text
ds = load_from_disk("drive/MyDrive/marco")
print(ctx_encoder.device)

def encode(batch):
    # Tokenize the whole batch, move the tensors to the GPU, bring pooler_output ([0]) back to the CPU
    inputs = ctx_tokenizer(batch['passages'], return_tensors="pt", padding="longest", truncation="longest_first").to('cuda:0')
    npys = ctx_encoder(**inputs)[0].cpu().numpy()
    return {'embeddings': npys}

ds_with_embeddings = ds.map(encode, batched=True, batch_size=100, num_proc=6)

This fails with the following error on a GPU Colab runtime:
RuntimeError: CUDA error: initialization error

Any help would be appreciated! Thanks a lot!

It is likely that the multiprocessing step does not work well with GPU-accelerated tasks, as it means duplicating the whole main process, and you’ll run out of memory. (Not sure why it runs into an initialization error, though.) Does it work without num_proc?

It’s possible that you should use tokenizer.batch_encode_plus here because you are passing a batch.
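On the tokenizer side, calling a Hugging Face tokenizer directly on a list of strings (as the code above does) already tokenizes the whole batch at once; `batch_encode_plus` is the older, deprecated spelling of the same operation. What `padding="longest"` then does is pad every sequence to the longest one in that batch and mark the padding in an attention mask. The plain-Python sketch below mimics only that padding step on made-up token ids; `pad_to_longest` and the pad id 0 are assumptions for illustration, not the tokenizer’s implementation.

```python
from typing import Dict, List

def pad_to_longest(batch_ids: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: pad each sequence to the batch's longest and build an attention mask."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch_ids]
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Made-up token ids for three passages of different lengths
encoded = pad_to_longest([[101, 7, 8, 102], [101, 9, 102], [101, 102]])
print(encoded["input_ids"])       # [[101, 7, 8, 102], [101, 9, 102, 0], [101, 102, 0, 0]]
print(encoded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0], [1, 1, 0, 0]]
```

Because each batch is padded only to its own longest passage, batches of similar-length passages waste less compute on padding.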