Column Name Mismatch Error while Streaming?

While streaming FineWeb (see code below), I get a “casting” error (see trace below) caused by mismatched column names. Specifically, the file CC-MAIN-2024-18/002_00009.parquet has an extra “filter_reason” column that isn't part of the dataset's expected features.
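Comparing the two schemas printed in the error shows which side has the extra column. This is a small sketch; both column lists are copied by hand from the CastError message in the trace below, and `diff_columns` is a hypothetical helper:

```python
from typing import Dict, List, Sequence

# Columns found in the parquet file (first list in the CastError message).
FILE_COLUMNS = [
    "text", "id", "dump", "url", "date", "file_path",
    "language", "language_score", "filter_reason", "token_count",
]
# Columns of the dataset's declared features (second list in the message).
EXPECTED_COLUMNS = [
    "text", "id", "dump", "url", "date", "file_path",
    "language", "language_score", "token_count",
]


def diff_columns(actual: Sequence[str], expected: Sequence[str]) -> Dict[str, List[str]]:
    """Report columns only in the file ('extra') or only in the schema ('missing')."""
    return {
        "extra": [c for c in actual if c not in expected],
        "missing": [c for c in expected if c not in actual],
    }


print(diff_columns(FILE_COLUMNS, EXPECTED_COLUMNS))
# → {'extra': ['filter_reason'], 'missing': []}
```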

Raising an error on a schema mismatch makes sense as the default behavior, but is there a built-in way to ignore that error or skip mismatched chunks of data?
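According to the traceback, the cast happens inside the parquet reader (`_generate_tables` → `_cast_table`) before a record reaches the script, and a Python generator that has raised cannot be resumed. Wrapping the `for` loop in `try`/`except` would therefore end the stream rather than skip the bad file. One workaround is to catch the error per shard (file). A minimal sketch, assuming each shard can be opened on its own; `iter_skipping_bad_shards` and `open_shard` are hypothetical names:

```python
from typing import Callable, Iterable, Iterator, List, Tuple, Type


def iter_skipping_bad_shards(
    shard_ids: Iterable[str],
    open_shard: Callable[[str], Iterable[dict]],
    skipped: List[Tuple[str, str]],
    errors: Tuple[Type[Exception], ...] = (Exception,),
) -> Iterator[dict]:
    """Yield records shard by shard, moving on when a shard's reader raises.

    Records already yielded from a failing shard are kept; the rest of that
    shard is lost, but iteration continues with the next shard.
    """
    for shard_id in shard_ids:
        try:
            for record in open_shard(shard_id):
                yield record
        except errors as exc:
            # Remember which shard was dropped and why, then continue.
            skipped.append((shard_id, repr(exc)))
```

With the datasets library, `shard_ids` could be the `hf://` parquet URLs (like the one in the trace) and `open_shard` something like `lambda path: load_dataset("parquet", data_files=path, split="train", streaming=True)`. This is untested against FineWeb. Reading files one at a time also gives up the shard-order shuffling that `fw.shuffle(...)` performs, so you would want to shuffle the file list yourself.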

Code:

import json
import pickle

from datasets import load_dataset
from tqdm import tqdm

# PriorityQueue and do_prediction_here are defined elsewhere in the script (not shown)

if __name__ == "__main__":
    SEED = 0
    SAMPLE_BUFFER_SIZE = 5_000
    RECORDS_TO_KEEP = 100_000
    TAKE_SIZE = 10_000_000  # 23,355,019,906 is max size

    fw = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
    fw = fw.shuffle(seed=SEED, buffer_size=SAMPLE_BUFFER_SIZE)
    with open("dataset_differentiator.pkl", "rb") as f:
        clf = pickle.load(f)
    priority_queue = PriorityQueue(RECORDS_TO_KEEP, key=lambda x: x["prob_control"])
    for sample in tqdm(fw.take(TAKE_SIZE)):
        # this is the domain prediction model, I can share more code if it seems relevant
        prediction = do_prediction_here(sample)
        priority_queue.add_record(prediction)

    with open("sampled_features_100k.json", "w") as f:
        json.dump(priority_queue.get_records(), f)
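`PriorityQueue` in the script is a custom class whose code isn't shown; it takes `key=` and has `add_record`/`get_records`, so it is not the standard library's `queue.PriorityQueue`. For anyone trying to reproduce the loop, here is a hypothetical stand-in, assuming the class should keep the `RECORDS_TO_KEEP` records with the highest `prob_control`. It is a bounded min-heap built on `heapq`:

```python
import heapq
import itertools
from typing import Any, Callable, List


class PriorityQueue:
    """Hypothetical stand-in: keeps the `capacity` records with the largest key."""

    def __init__(self, capacity: int, key: Callable[[Any], float]):
        self.capacity = capacity
        self.key = key
        self._heap: list = []  # min-heap of (key, tiebreak, record)
        self._counter = itertools.count()  # tiebreak so records are never compared

    def add_record(self, record: Any) -> None:
        entry = (self.key(record), next(self._counter), record)
        if len(self._heap) < self.capacity:
            heapq.heappush(self._heap, entry)
        elif entry[0] > self._heap[0][0]:
            # Evict the current smallest key to make room for a larger one.
            heapq.heapreplace(self._heap, entry)

    def get_records(self) -> List[Any]:
        # Highest key first; sort on (key, tiebreak) only, never on the record.
        ordered = sorted(self._heap, key=lambda e: (e[0], e[1]), reverse=True)
        return [record for _, _, record in ordered]
```

The heap's root is always the weakest kept record, so each `add_record` costs O(log k) and memory stays bounded at `capacity` entries no matter how many samples are streamed.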

Error Trace

176147160it [69:33:58, 811.53it/s]
Failed to read file 'hf://datasets/HuggingFaceFW/fineweb@29be36a2e035737f9b2d7e4f0847413ff7b2994b/data/CC-MAIN-2024-18/002_00009.parquet' with error <class 'datasets.table.CastError'>: Couldn't cast
text: string
id: string
dump: string
url: string
date: string
file_path: string
language: string
language_score: double
filter_reason: string
token_count: int64
to
{'text': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'dump': Value(dtype='string', id=None), 'url': Value(dtype='string', id=None), 'date': Value(dtype='string', id=None), 'file_path': Value(dtype='string', id=None), 'language': Value(dtype='string', id=None), 'language_score': Value(dtype='float64', id=None), 'token_count': Value(dtype='int64', id=None)}
because column names don't match
176147212it [69:33:59, 703.35it/s]
Traceback (most recent call last):
  File "/home/felix_l_labelle/neoberta/dataset_filtering/fineweb_curation.py", line 88, in <module>
    for sample in tqdm(dl):
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
    data.reraise()
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/torch/_utils.py", line 704, in reraise
    raise RuntimeError(msg) from None
RuntimeError: Caught CastError in DataLoader worker process 12.
Original Traceback (most recent call last):
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 41, in fetch
    data = next(self.dataset_iter)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1368, in __iter__
    yield from self._iter_pytorch()
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1303, in _iter_pytorch
    for key, example in ex_iterable:
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1044, in __iter__
    yield from islice(self.ex_iterable, self.n)
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 987, in __iter__
    for x in self.ex_iterable:
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 282, in __iter__
    for key, pa_table in self.generate_tables_fn(**self.kwargs):
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/packaged_modules/parquet/parquet.py", line 97, in _generate_tables
    yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/packaged_modules/parquet/parquet.py", line 75, in _cast_table
    pa_table = table_cast(pa_table, self.info.features.arrow_schema)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/table.py", line 2295, in table_cast
    return cast_table_to_schema(table, schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix_l_labelle/anaconda3/envs/data_processing/lib/python3.11/site-packages/datasets/table.py", line 2249, in cast_table_to_schema
    raise CastError(
datasets.table.CastError: Couldn't cast
text: string
id: string
dump: string
url: string
date: string
file_path: string
language: string
language_score: double
filter_reason: string
token_count: int64
to
{'text': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'dump': Value(dtype='string', id=None), 'url': Value(dtype='string', id=None), 'date': Value(dtype='string', id=None), 'file_path': Value(dtype='string', id=None), 'language': Value(dtype='string', id=None), 'language_score': Value(dtype='float64', id=None), 'token_count': Value(dtype='int64', id=None)}
because column names don't match
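If records are read in a way where the extra column reaches your code instead of failing the cast (for example, when each parquet file is loaded separately), one option is to keep only the columns of the target schema shown in the error, so every record has the same keys. A plain-Python sketch; `EXPECTED_COLUMNS` is copied from the target features above and `align_columns` is a hypothetical helper:

```python
from typing import Dict, Iterable, Iterator, Sequence

# Columns of the target features printed in the CastError above.
EXPECTED_COLUMNS = (
    "text", "id", "dump", "url", "date", "file_path",
    "language", "language_score", "token_count",
)


def align_columns(
    records: Iterable[Dict], expected: Sequence[str] = EXPECTED_COLUMNS
) -> Iterator[Dict]:
    """Keep only the expected columns, in order; fail loudly if one is absent."""
    for record in records:
        missing = [c for c in expected if c not in record]
        if missing:
            raise KeyError(f"record is missing columns: {missing}")
        # Extra keys such as 'filter_reason' are silently dropped here.
        yield {c: record[c] for c in expected}
```

Dropping extra columns is safe for this loop as long as `do_prediction_here` only needs the shared fields; a genuinely missing column still raises, since silently filling it would hide a real schema problem. Recent versions of the datasets library also have `IterableDataset.select_columns`, which should do the same thing on a streamed dataset, though that hasn't been checked against FineWeb here.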