Hello HuggingFace community,
I recently created a streaming dataset that is sharded into 8 parquet files for the training set and 1 parquet file each for validation and testing. When I iterate through the entire training split it triggers a StopIteration, as expected, and training moves on to the next epoch. The problem is that the dataset stays stuck on that StopIteration: if I call next(iter(dataset['train'])) after an epoch, it still raises StopIteration and never resets, so I cannot run another epoch.
I have implemented a near-identical dataloader in a Linux environment without this issue, but now I am attempting the same thing in a Windows environment and running into the behavior above. Does anyone know what I am doing wrong?
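Here is a minimal sketch of the behavior I mean (same parquet files and streaming setup as in the full code further down):

from datasets import load_dataset

ds = load_dataset(
    'parquet',
    data_files={'train': "input_data/datasets/*train*parquet"},
    streaming=True
)
for _ in ds['train']:
    pass                       # exhaust one full epoch
next(iter(ds['train']))        # expected: first example of a fresh epoch
                               # actual (on Windows): raises StopIteration again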
Take a look at the code below. The main point is that it fails after one full pass through the iterable dataset, i.e. at the start of the second epoch.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

class LoadObj:
    def __init__(self,
                 #dictionary_path: str,
                 dobj: DicObj,  # DicObj is defined elsewhere in my code
                 dataset_path: dict = None,
                 batch_size: int = 100,
                 embed: bool = False,
                 num_workers: int = 0,
                 buffer_size: int = 1000,
                 **kwargs
                 ):
        self.D = dobj
        self.embed = embed
        self.channels = dobj.seq_channels if embed else dobj.channels

        # Dataset
        if dataset_path is not None:
            dataset = load_dataset(
                'parquet',
                data_files=dataset_path,
                streaming=True
            )
            # Filter for charge and length
            dataset = dataset.filter(
                self.filter_charge_and_length
            )
            # Map to format outputs
            dataset = dataset.map(
                self.map_fn,
                remove_columns=['sequence', 'charge', 'mod_pos', 'mod_type', 'nce', 'ev', 'ion', 'ab']
            )
            # Shuffle dataset
            dataset['train'] = dataset['train'].shuffle(buffer_size=buffer_size)
            self.dataset = dataset

        def build_dataloader(dataset, batch_size, num_workers):
            return DataLoader(
                dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                collate_fn=self.collate_fn,
            )

        # Dataloaders
        num_workers = min(self.dataset['train'].n_shards, num_workers)
        self.dataloader = {
            'train': build_dataloader(dataset['train'], batch_size, num_workers),
            'val':   build_dataloader(dataset['val'],   batch_size, 0),
            'test':  build_dataloader(dataset['test'],  batch_size, 0),
        }
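        # Note: a streaming IterableDataset can feed at most one worker per
        # shard, so num_workers is capped at n_shards (8 for train); val and
        # test are single-shard, hence num_workers=0 for those loaders.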
"""
Windows multiprocessing prevents using lambda functions for map and filter
AttributeError: Can't pickle local object 'LoaderHF.__init__.<locals>.<lambda>'
See this message board: https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857
"""
    def filter_charge_and_length(self, example):
        boolean = (
            (example['charge'] >= self.D.chlim[0]) &
            (example['charge'] <= self.D.chlim[1]) &
            (len(example['sequence']) <= self.D.seq_len)
        )
        return boolean
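    # For reference, the lambda version of this filter is what fails on
    # Windows: spawned worker processes must pickle the function, and a
    # local lambda cannot be pickled, whereas a bound method like
    # filter_charge_and_length can. Illustration of what I originally tried:
    #   dataset = dataset.filter(lambda ex: len(ex['sequence']) <= self.D.seq_len)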
"""
Saving large sparse arrays to parquet file is very slow to
read in (e.g. target vector). Rather one should save the data
in the most dense form possible and process it after being read in.
"""
    def map_fn(self, example):
        input_tensor = torch.zeros((self.channels, self.D.seq_len), dtype=torch.float32)
        # Sequence
        seq = example['sequence']
        assert len(seq) <= self.D.seq_len, "Exceeded maximum peptide length."
        input_tensor[:len(self.D.dic), :len(seq)] = self.sequence_one_hot(seq)
        input_tensor[len(self.D.dic)-1, len(seq):] = 1.  # mark positions past the sequence end
        # PTMs
        input_tensor[len(self.D.dic)] = 1.  # default: no modification anywhere
        if example['mod_type'][0] != 'null':
            for pos, modtyp in zip(example['mod_pos'], example['mod_type']):
                input_tensor[self.D.mdic[modtyp], int(pos)] = 1.
                input_tensor[len(self.D.dic), int(pos)] = 0.
        # Charge
        charge = example['charge']
        input_tensor[self.D.seq_channels + charge - 1] = 1.
        # eV
        input_tensor[-1, :] = example['ev'] / 100.
        # Target: expand the dense (ion, ab) lists into a full abundance vector
        output_tensor = torch.zeros((len(self.D.dictionary),), dtype=torch.float32)
        for ion, ab in zip(example['ion'], example['ab']):
            output_tensor[self.D.dictionary[ion]] = ab
        output_tensor /= output_tensor.max()  # normalize to the maximum abundance
        example['input_tensor'] = input_tensor
        example['target_tensor'] = output_tensor
        return example
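    # Worked example of the target construction above: with a toy dictionary
    # self.D.dictionary = {'b2': 0, 'y1': 1} and an example carrying
    # ion=['b2', 'y1'], ab=[0.2, 0.8], output_tensor is first [0.2, 0.8]
    # and then [0.25, 1.0] after dividing by its max.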
    def collate_fn(self, batch_list):
        inp = torch.stack([m['input_tensor'] for m in batch_list])
        out = torch.stack([m['target_tensor'] for m in batch_list])
        return inp, out
    def sequence_one_hot(self, sequence):
        return torch.nn.functional.one_hot(
            torch.tensor([self.D.dic[o] for o in sequence], dtype=torch.long),
            len(self.D.dic)
        ).T
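For context, sequence_one_hot returns a (len(dic), len(seq)) block that is written into the top-left of input_tensor. A toy example with a hypothetical 3-letter alphabet:

# Hypothetical alphabet, just to illustrate the output shape:
# with D.dic = {'A': 0, 'C': 1, 'G': 2}, sequence_one_hot('ACA') returns
# tensor([[1, 0, 1],
#         [0, 1, 0],
#         [0, 0, 0]])   # one row per symbol, one column per residue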
L = LoadObj(
    dataset_path={
        'train': "input_data/datasets/*train*parquet",
        'val':   "input_data/datasets/*val*parquet",
        'test':  "input_data/datasets/*test*parquet",
    },
    dobj=D,
    embed=False,
    batch_size=100,
    num_workers=8,
    remove_columns=['ab', 'ion', 'ev', 'nce', 'charge', 'sequence'],
)
for epoch in range(2):
    for i, batch in enumerate(L.dataloader['val']):
        print("\r%d" % i, end='')
# Will fail at the start of epoch==1 due to a StopIteration exception