DPR context tokenization on a GPU

Hello everyone, I hope you are all having fun using the Hugging Face library.

I am tokenizing the 8.8M passages from the MS MARCO dataset. I have also loaded them into a Hugging Face dataset because I want to add a FAISS index over it afterwards.

To do this, I created the dataset by following these steps. Afterwards, I ran this code:

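import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

torch.set_grad_enabled(False)  # inference only, no gradients needed
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
# Embed one passage per call; [0][0] takes the pooled embedding of that single passage
dataset_embedded_passages = dataset_passages.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example["passage"], return_tensors="pt"))[0][0].numpy()})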

This will take a lot of time on the CPU since the dataset has ~8.8M passages. Is it possible to do the tokenization on the GPU? I checked the .map() method and did not find a way to move the data to the device.

Hi!
You can indeed put the tokenized text on the GPU and give it to the model.
You can also make it significantly faster by using a batched map:

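def embed(examples):
    # Tokenize the whole batch on the CPU, then move the tensors to the encoder's device.
    # This only uses the GPU if ctx_encoder was moved there first, e.g. ctx_encoder.to("cuda").
    tokenized_examples = ctx_tokenizer(
        examples["passage"],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=512
    ).to(device=ctx_encoder.device)
    # [0] is the pooled output: one embedding per passage in the batch
    embeddings = ctx_encoder(**tokenized_examples)[0]
    # Bring the result back to the CPU as a NumPy array so it can be written to the dataset
    return {"embeddings": embeddings.detach().cpu().numpy()}

dataset_embedded_passages = dataset_passages.map(embed, batched=True, batch_size=16)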
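
To make the batched contract concrete, here is a rough, library-free sketch of what a batched map does with the function it is given: the function receives a dict of columns (each a list of up to batch_size values) and returns a dict of new columns. Note that batched_map and fake_embed are hypothetical names made up for this illustration; this is not how the datasets library is implemented, and the real map also keeps the input columns.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]


def batched_map(columns: Batch, fn: Callable[[Batch], Batch], batch_size: int) -> Batch:
    """Toy stand-in for a batched map: call fn once per slice of batch_size rows."""
    num_rows = len(next(iter(columns.values())))
    out: Batch = {}
    for start in range(0, num_rows, batch_size):
        # Each batch is a dict of column name -> list of values for that slice
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        # Collect the columns returned by fn, batch after batch
        for name, values in fn(batch).items():
            out.setdefault(name, []).extend(values)
    return out


calls = []  # records how many passages each call received


def fake_embed(examples: Batch) -> Batch:
    # Placeholder for tokenizer + encoder: one fake "embedding" per passage
    calls.append(len(examples["passage"]))
    return {"embeddings": [[float(len(p))] for p in examples["passage"]]}


passages = {"passage": [f"passage number {i}" for i in range(10)]}
result = batched_map(passages, fake_embed, batch_size=4)
print(calls)  # [4, 4, 2]
```

With 10 passages and a batch size of 4, the function runs 3 times instead of 10, which is where the speedup comes from once each call is a single forward pass on the GPU.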

Let me know if it helps :slight_smile:

Hi Quentin! Thank you for your reply.

I have two more questions regarding the map method and caching.

  1. In the DPR paper, they used a batch size of 4 for larger datasets like NQ, Trivia and SQuAD. It’s probably better to use this batch size as well, since I’m working with a dataset as big as MS MARCO. Do you agree?

  2. When I run the .map() method, a logger message appears saying: Caching processed dataset at....
    Could you explain what is happening behind the scenes? Are you bringing data from disk to RAM and caching it for faster read access?

  1. A batch size of 4 is reasonable as well. I guess it is set to this value in the paper for training. Since you’re doing inference, you can increase it if you want to.

  2. Datasets are stored on disk and the library uses memory mapping to access them. This makes it possible to load huge datasets quickly without filling the RAM. The caching is also done on disk: when you call map, the new dataset is written to your disk, and calling map again later will reuse the same file.
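
As a rough illustration of the memory-mapping idea, the sketch below uses only Python's standard mmap module: a file of fixed-size records is written to disk, and a single record is then read through a memory map without loading the whole file into RAM. The file layout and the helper names (write_records, read_record) are assumptions made up for this example; it is not the datasets library's own code or on-disk format.

```python
import mmap
import os
import struct
import tempfile

RECORD = struct.Struct("<Q")  # one little-endian unsigned 64-bit integer per record


def write_records(path: str, count: int) -> None:
    """Write count fixed-size records (the squares 0, 1, 4, ...) to disk."""
    with open(path, "wb") as f:
        for i in range(count):
            f.write(RECORD.pack(i * i))


def read_record(path: str, index: int) -> int:
    """Read one record through a memory map; only the pages touched are loaded."""
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            (value,) = RECORD.unpack_from(mm, index * RECORD.size)
            return value


path = os.path.join(tempfile.mkdtemp(), "records.bin")
write_records(path, 100_000)
print(read_record(path, 12_345))
```

The read jumps straight to the byte offset of the requested record, so the cost of one lookup does not depend on how large the file is.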

Hope this answers your questions :slight_smile:

Thank you, I will increase the batch size then.

That’s great. I really noticed the data being loaded quickly.
By the way, is there a similar efficient mechanism for loading huge datasets when using your FAISS index library? (with datasets.Dataset.add_faiss_index() and get_nearest_examples)