Iterable dataset stuck on StopIteration after the first epoch

Hello HuggingFace community

I recently created a streaming dataset that is sharded into 8 parquet files for the training set and 1 parquet file each for validation and testing. When I iterate through the entire training set, it of course raises a StopIteration and moves on to the next epoch, but the problem is that the dataset stays stuck on that StopIteration. In other words, if I call next(iter(dataset['train'])) after an epoch, it still raises StopIteration and never resets, so I cannot run another epoch.
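To illustrate what I mean, here is a stripped-down sketch of the behavior I expect (using the same parquet globs as my real data): calling iter() on the streaming split again should restart it from the beginning, but after one full pass it just keeps raising StopIteration.

from datasets import load_dataset

# Stripped-down sketch; same parquet glob as my real training shards
dataset = load_dataset(
    'parquet',
    data_files={'train': "input_data/datasets/*train*parquet"},
    streaming=True,
)

# First pass over the stream: works fine
for example in dataset['train']:
    pass

# I would expect a fresh iterator to restart from the beginning,
# but after the first pass this immediately raises StopIteration on Windows
next(iter(dataset['train']))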

I have implemented a nearly identical dataloader in a Linux environment without any such issue, but now I am attempting this in a Windows environment and running into this problem. Does anyone know what I am doing wrong?

Take a look at the code below. The main point is that it fails after one full pass through the iterable dataset, i.e. at the start of the second epoch.

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

# DicObj is my own helper class (not shown) holding the dictionaries and
# dimensions used below (dic, mdic, dictionary, chlim, seq_len, channels, seq_channels)
class LoadObj:
    def __init__(self,
        #dictionary_path: str,
        dobj: DicObj,
        dataset_path: dict=None,
        batch_size: int=100,
        embed: bool=False,
        num_workers: int=0,
        buffer_size: int=1000,
        **kwargs
    ):
        self.D = dobj
        self.embed = embed
        self.channels = dobj.seq_channels if embed else dobj.channels
        
        # Dataset
        if dataset_path is not None:
            dataset = load_dataset(
                'parquet',
                data_files=dataset_path,
                streaming=True
            )
            
            # Filter for length
            dataset = dataset.filter(
                self.filter_charge_and_length
            )
            
            # Map to format outputs
            dataset = dataset.map(
                self.map_fn,
                remove_columns=['sequence', 'charge', 'mod_pos', 'mod_type', 'nce', 'ev', 'ion', 'ab']
            )
            
            # Shuffle dataset
            dataset['train'] = dataset['train'].shuffle(buffer_size=buffer_size)
            
            self.dataset = dataset
            
            def build_dataloader(dataset, batch_size, num_workers):
                return DataLoader(
                    dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    collate_fn=self.collate_fn,   
                )
            
            # Dataloaders
            num_workers = min(self.dataset['train'].n_shards, num_workers)
            self.dataloader = {
                'train': build_dataloader(dataset['train'], batch_size, num_workers),
                'val':   build_dataloader(dataset['val'],   batch_size, 0),
                'test':  build_dataloader(dataset['test'],  batch_size, 0),
            }
    
    """
    Windows multiprocessing prevents using lambda functions for map and filter
    
    AttributeError: Can't pickle local object 'LoaderHF.__init__.<locals>.<lambda>'
    
    See this message board: https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857
    """
    
    def filter_charge_and_length(self, example):
        boolean = (
            (example['charge'] >= self.D.chlim[0]) &
            (example['charge'] <= self.D.chlim[1]) &
            (len(example['sequence']) <= self.D.seq_len)
        )
        return boolean
    
    """
    Saving large sparse arrays to parquet file is very slow to 
    read in (e.g. target vector). Rather one should save the data 
    in the most dense form possible and process it after being read in.
    """
    def map_fn(self, example):
        input_tensor = torch.zeros((self.channels, self.D.seq_len), dtype=torch.float32)
        # Sequence
        seq = example['sequence']
        assert len(seq) <= self.D.seq_len, "Exceeded maximum peptide length."
        input_tensor[:len(self.D.dic), :len(seq)] = self.sequence_one_hot(seq)
        input_tensor[len(self.D.dic)-1, len(seq):] = 1.
        # PTMs
        input_tensor[len(self.D.dic)] = 1.
        if example['mod_type'][0] != 'null':
            for pos, modtyp in zip(example['mod_pos'], example['mod_type']):
                input_tensor[self.D.mdic[modtyp], int(pos)] = 1.
                input_tensor[len(self.D.dic), int(pos)] = 0.
        # Charge
        charge = example['charge']
        input_tensor[self.D.seq_channels+charge-1] = 1.
        # eV
        input_tensor[-1, :] = example['ev'] / 100.
        
        output_tensor = torch.zeros((len(self.D.dictionary),), dtype=torch.float32)
        for ion, ab in zip(example['ion'], example['ab']):
            output_tensor[self.D.dictionary[ion]] = ab
        output_tensor /= output_tensor.max()
        
        example['input_tensor'] = input_tensor
        example['target_tensor'] = output_tensor
        
        return example

    def collate_fn(self, batch_list):
        inp = torch.stack([m['input_tensor'] for m in batch_list])
        out = torch.stack([m['target_tensor'] for m in batch_list])
        return inp, out

    def sequence_one_hot(self, sequence):
        return torch.nn.functional.one_hot(
            torch.tensor([self.D.dic[o] for o in sequence], dtype=torch.long),
            len(self.D.dic)
        ).T

# D is a DicObj instance created elsewhere
L = LoadObj(
        dataset_path={
            'train': "input_data/datasets/*train*parquet",
            'val':   "input_data/datasets/*val*parquet",
            'test':  "input_data/datasets/*test*parquet",
        },
        dobj=D, 
        embed=False,
        batch_size=100,
        num_workers=8,
        remove_columns=['ab','ion','ev','nce','charge','sequence'],
    )

for epoch in range(2):
    for i, batch in enumerate(L.dataloader['val']):
        print("\r%d"%i, end='')
    # Fails at the start of epoch == 1 (the second epoch) due to a StopIteration exception

Hmm, I'm not sure what could be causing this. Would you be able to make a minimal reproducible example? Maybe without map/filter/DataLoader? This would help figure out the culprit.
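For example, something along these lines would already help (just load and iterate twice, without map/filter/DataLoader; point it at one of your parquet files):

from datasets import load_dataset

ds = load_dataset('parquet', data_files={'val': "input_data/datasets/*val*parquet"}, streaming=True)['val']

for epoch in range(2):
    n = 0
    for example in ds:
        n += 1
    # If the second epoch yields 0 examples (or raises), the issue is in datasets itself
    print(f"epoch {epoch}: {n} examples")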

Hey Quentin, thanks for responding

Before I go to the trouble of producing a minimal example, would it help to know that this issue occurred in a Windows environment, but when I cloned the code onto a Linux cluster and ran it, everything worked exactly as intended?

Tbh even with this info I'm not sure. I was thinking of a side effect of a DataLoader worker, but you're using 0 workers for the "val" set.

My second guess was a dataset resuming bug (related to state_dict() and load_state_dict()), but dataset resuming should behave the same on Linux and Windows, so it's unlikely it…

unless you're not using the same version of datasets in your Linux env and Windows env? Dataset resuming was added pretty recently.
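You can compare with:

import datasets
print(datasets.__version__)  # run this in both the Windows and Linux environments

And for reference, the resuming mechanism I'm referring to looks roughly like this (a sketch, only available in recent versions of datasets):

from datasets import load_dataset

ds = load_dataset('parquet', data_files={'val': "input_data/datasets/*val*parquet"}, streaming=True)['val']

# checkpoint the position of the stream mid-epoch...
for i, example in enumerate(ds):
    if i == 10:
        state = ds.state_dict()
        break

# ...and later resume from that position instead of restarting from the beginning
ds.load_state_dict(state)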