Slow train split generation when loading a local dataset

I have a dataset with 500 labels. Each label has around 20 GB of raw image data (100k+ rows), so the whole dataset looks like
./dataset/label1/data-00000-of-00001.arrow, ./dataset/label2/data-00000-of-00001.arrow ...

For training, I want each batch to randomly select batch_size labels and randomly sample 256 images from each selected label. The output of the data loader has shape (bs, num_sample, 3, H, W).

import os
import random

from datasets import load_dataset
from torch.utils import data


class Dataset(data.Dataset):
    def __init__(self, dataset_path, num_sample=256, seed=666, num_workers=1):
        """
        read dataset
        """
        self.dataset_path = dataset_path
        self.labels = list_labels(dataset_path)
        self.num_sample = num_sample
        self.num_workers = num_workers
        self.seed = seed
        print(f'total subject id is {len(self.labels)}')

    def __getitem__(self, index):
        """
        Randomly sample a list of frames from a random subject id
        """
        gids = random.choices(self.labels, k=1)
        files = [get_arrow_files(os.path.join(self.dataset_path, gid)) for gid in gids]
        files = [item for sublist in files for item in sublist]
        data_files = {"train": files}
        ds = load_dataset("arrow", split='train', data_files=data_files)
        ds = ds.shuffle(seed=self.seed)
        ds = ds.to_iterable_dataset(num_shards=self.num_workers)  # num_shards should be equal to num_workers

        # ds = Dataset.from_file(files[0])
        img_feats, labels = [], []
        for idx, batch in enumerate(ds):
            # load row here
            .......
            if (idx + 1) == self.num_sample:
                break
        return img_feats, labels

    def __len__(self):
        # No actual length since we are randomly sampling
        return 999999


class DataLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0, modality_type=None):
        super(DataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
            collate_fn=self.collate_fn)
        self.num_class = len(dataset.labels)
        self.modality_type = modality_type

    def collate_fn(self, batch):
        """
        Custom collate function to handle multiple inputs.
        """
        # (Bs, N, 3, 112, 112)
        img_feats, labels = zip(*batch)

        # preprocess feature here
        .......

        return (img_feats, labels)
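
For reference, `list_labels` and `get_arrow_files` are simple helpers over the directory layout above; roughly like this (a sketch, the exact implementation is omitted here):

import glob
import os


def list_labels(dataset_path):
    # assumed: each label is a sub-directory of the dataset root
    return sorted(
        d for d in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, d))
    )


def get_arrow_files(label_dir):
    # assumed: every .arrow shard inside a label directory belongs to that label
    return sorted(glob.glob(os.path.join(label_dir, "*.arrow")))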

When I ran this code on my local machine with only 2 labels' worth of data and tested with bs=2, num_sample=32, it ran fast. But when I moved to the server with 522 labels' worth of data and tested with bs=16, num_sample=32, it started generating the train split first:

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 12.35it/s]

Generating train split: 0 examples [00:00, ? examples/s]
...
Generating train split: 67378 examples [00:18, 3652.42 examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]
Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 9446.63it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 10.50it/s]

Generating train split: 0 examples [00:00, ? examples/s]
Generating train split: 1000 examples [00:00, 3054.78 examples/s]
Generating train split: 2000 examples [00:00, 2939.25 examples/s]
Generating train split: 3000 examples [00:01, 2917.95 examples/s]
Generating train split: 4000 examples [00:01, 3083.18 examples/s]
...

Loading the dataset became really slow, and I noticed it writes a huge cache file to disk. I'm new to Hugging Face and I'm not sure if there's an optimization or a better approach.

Thanks!

Hi ! You can pass a HF Dataset directly to your data loader, after defining your logic:

from datasets import interleave_datasets, load_dataset
from torch.utils.data import DataLoader  # or your custom DataLoader subclass


def group(batch):
    # group a batch of 256 images from one label as one example
    return {k: [v] for k, v in batch.items()}

# data_files_for_labels: one list of .arrow files per label
labels_datasets = [
    load_dataset("arrow", split='train', data_files=data_files_for_label)
    for data_files_for_label in data_files_for_labels
]

# each batch can randomly sample 256 images from a label, and randomly select batch_size of labels
grouped = [ds.shuffle().map(group, batched=True, batch_size=256) for ds in labels_datasets]
interleaved = interleave_datasets(grouped, probabilities=[1/len(grouped)] * len(grouped))
dataloader = DataLoader(interleaved, batch_size=batch_size)
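
For your layout, `data_files_for_labels` could be built roughly like this (a sketch; the `dataset_path` value is an assumption based on the paths in your post):

import glob
import os

dataset_path = "./dataset"  # adjust to your dataset root; one sub-directory per label

# one list of .arrow files per label, e.g. ./dataset/label1/data-00000-of-00001.arrow
data_files_for_labels = [
    sorted(glob.glob(os.path.join(dataset_path, label, "*.arrow")))
    for label in sorted(os.listdir(dataset_path))
]

Each example of `interleaved` is then one group of 256 rows from a single label, so a regular DataLoader batch gives you the (bs, num_sample, ...) structure; you may still need a small collate_fn or `.with_format("torch")` to turn the image data into tensors.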