Fetching data takes too too much time

class SQUAD(Dataset):
    def __init__(self):
        # Load our training dataset and tokenizer
        self.dataset = load_dataset("squad", split="train")
        self.encoded_context = self.dataset.map(
            convert_to_features_context, batched=True
        )
        #self.dataset = self.dataset.flatten()
        self.encoded_question = self.dataset.map(
            convert_to_features_question, batched=True
        )
        # Format our dataset to outputs torch.Tensor to train a pytorch model
        columns = ["input_ids", "start_positions", "end_positions"]
        self.encoded_context.set_format(type="torch", columns=columns)
        self.encoded_question.set_format(type="torch", columns=["input_ids"])
        self.length = len(self.encoded_context["input_ids"])
        self.encoded_context.flatten()
        self.encoded_question.flatten()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        t1 = time.time()
        input_ids_context = self.encoded_context["input_ids"][idx]
        print("context", time.time() - t1)  #0.38 sec      
        t2 = time.time()
        input_ids_question = self.encoded_question["input_ids"][idx]
        print("question", time.time() - t2) #0.40 sec
        t3 = time.time()
        input_ids_start_positions = self.encoded_context["start_positions"][idx]
        print("start", time.time() - t3)  #0.40 sec
        t4 = time.time()
        input_ids_end_positions = self.encoded_context["end_positions"][idx]
        print("end", time.time() - t4)  #0.68 sec
        
        return (
            input_ids_context,
            input_ids_question,
            input_ids_start_positions,
            input_ids_end_positions,
        )

# in main function

train_dataset = SQUAD()
dataloader = DataLoader(train_dataset, batch_size=20)
for batch in dataloader :   
         ...

I made my own custom dataset and I brought Squad dataset from Huggingface.

The problem is fetching data takes too much time.

In getitem function, returning 4 values takes about 1.8 seconds.

when batch size =20, it takes about 1.8 * 20 = 36 seconds.

It looks like this code has a huge problem.

Is there any workaround?

2 Likes

Answered here.

1 Like