Iterating on dataset extremely slow

I’ve tried a few different approaches to iterating over 1M records with datasets, but all of them seem quite slow. This time I’m trying the recommended streaming method from the recent post, but it’s still extremely slow. It seems there are always multiple ways to iterate over a dataset, but most (or all) of them do not yield satisfactory speed.

import s3fs
s3 = s3fs.S3FileSystem()

import datasets as hfdt
from torch.utils.data import DataLoader
SEED = 0
SAMPLE_BUFFER_SIZE=5_000
RECORDS_TO_KEEP= 100_000
TAKE_SIZE = 1_000_000 

dt_train = hfdt.load_dataset('parquet', data_files = ['s3://' + i for i in s3.glob('s3://my_s3_bucket/*')], streaming=True)
dt_train = dt_train['train'].shuffle(seed=SEED, buffer_size=SAMPLE_BUFFER_SIZE)

import torch
import pandas as pd
def custom_collate_fn(batch):
    # batch = list(filter(lambda x: x is not None, batch))  # Remove None values
    # Replace every sample with the constant 0, so the actual row contents are discarded
    batch = [0 for _ in batch]
    return torch.utils.data.dataloader.default_collate(batch)  # Use the default collate function

dl_train = DataLoader(dt_train.take(TAKE_SIZE), num_workers=8, prefetch_factor=2, batch_size=128, collate_fn=custom_collate_fn)

from tqdm import tqdm
for sample in tqdm(dl_train):
    pass

The s3 folder has 200 parquet files and in total it has 1M rows and 28 columns. However, for such a small dataset, it takes forever to iterate on the data:

Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 27.81it/s]
5381it [1:24:45,  1.25it/s]

It took almost 2 hours to finish:

real    118m42.879s
user    421m2.584s
sys     0m55.513s

I’m wondering whether HF datasets is really feasible for moderately sized datasets?

@zhh210 The post you referenced was mine, and I’m able to stream 900+ records per second from the FineWeb dataset while running inference with a linear model over the text as it is read in.

What type of data are you loading (images, text, tabular), and how big are these files? If increasing the number of workers didn’t increase speed, that might be a sign that the bottleneck is the download (or upload on the other end) speed.

@FelixLabelle It was a surprise to me as well. My parquet files (.snappy.parquet) are tabular data with only numeric values and each of the 200 parquet files is just 1.3M.

Ok, in that case I would just try substitutions to debug potential sources of the issue. My guesses at the causes, in order of likelihood, are:

  1. Upload speed limitations (the servers you are downloading from)
  2. (tied for first) Download speed limitations
  3. Something specific to your data
  4. Something in your code which is adding time unbeknownst to you

For 1., you can try downloading similar datasets from HF directly and see if you still have the issue. If you don’t, I would check upload speed. I don’t use AWS, but certain storage methods can be slower to get data from and would be a bottleneck.

This test could also be indicative for cause 3. If it fails and upload looks OK, I would try streaming a much smaller dummy dataset and seeing if that’s slow as well. If it is, then I would start stripping away parts of your dataset to see if a specific column or the number of columns is causing the issue.
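
One way to make this kind of substitution test concrete is to time each stage in isolation and compare rows per second. The helper below is a minimal sketch using only the standard library; `rows_per_second` is a hypothetical name (not part of `datasets` or PyTorch), and the `range` input is just a toy stand-in for a real iterable.

```python
import time
from itertools import islice


def rows_per_second(iterable, limit):
    """Consume up to `limit` items and return (items_seen, items_per_second)."""
    start = time.perf_counter()
    seen = sum(1 for _ in islice(iterable, limit))
    # Guard against a zero-length interval on very fast iterables.
    elapsed = max(time.perf_counter() - start, 1e-9)
    return seen, seen / elapsed


# Toy stand-in; any iterable can be measured the same way.
seen, rate = rows_per_second(range(1_000_000), 10_000)
print(f"{seen} items at {rate:,.0f} items/s")
```

Assuming the names from the first post, the same call could be pointed at the streaming dataset on its own (`rows_per_second(dt_train, 10_000)`) and then at a `DataLoader` built with `num_workers=0`. A large gap between the two would suggest the slowdown is in batching or worker hand-off rather than in reading the files.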

For 2., if you are using a computer, just run a download speed test in your browser. I would also turn off any VPNs if you can; those can add time.

Cause 4 can be profiled; I would try that if you haven’t already. That would likely be required anyway if you’ve gotten this far in debugging.
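
For reference, here is a minimal profiling sketch with the standard library's `cProfile` and `pstats`. Sorting by cumulative time and printing only the top entries usually makes the hot path easier to spot than a listing ordered by standard name. `slow_step` and `run_loop` are hypothetical placeholders for the real iteration loop.

```python
import cProfile
import io
import pstats
import time


def slow_step():
    # Placeholder for the expensive part of one iteration.
    time.sleep(0.01)


def run_loop(n=5):
    # Placeholder for the loop being investigated.
    for _ in range(n):
        slow_step()


profiler = cProfile.Profile()
profiler.enable()
run_loop()
profiler.disable()

# Sort by cumulative time and keep only the 10 most expensive entries.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(10)
report = buffer.getvalue()
print(report)
```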

If that doesn’t narrow down the problem, LMK and we can look at the other causes.

I tried loading the dataset from local disk instead of from S3; it does reduce the running time from 2 hours to 1 hour:

CPU times: user 9.21 s, sys: 2.87 s, total: 12.1 s
Wall time: 1h 51s

Profiling shows that most of the time was spent in select.poll(), so I guess this means it’s slow at reading the .snappy.parquet files:

7811it [2:33:58,  1.18s/it]
         3624330 function calls (3624329 primitive calls) in 9238.469 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.899    0.899 9238.469 9238.469 415229234.py:3(run_local)
     7871    0.022    0.000    0.030    0.000 <frozen importlib._bootstrap>:404(parent)
     8357    0.012    0.000    0.019    0.000 <string>:1(<lambda>)
        1    0.000    0.000 9238.469 9238.469 <string>:1(<module>)
     7812    0.003    0.000    0.003    0.000 __init__.py:127(annotate)
        8    0.000    0.000    0.000    0.000 __init__.py:219(_acquireLock)
        8    0.000    0.000    0.000    0.000 __init__.py:228(_releaseLock)
        1    0.000    0.000    0.000    0.000 __init__.py:9(is_available)
     7812    0.004    0.000    0.004    0.000 _jit_internal.py:1120(is_scripting)
        1    0.000    0.000    0.000    0.000 _monitor.py:94(report)
     7812    0.015    0.000    1.027    0.000 _ops.py:591(__call__)
     7812    0.018    0.000    0.479    0.000 _ops.py:846(__call__)
     7811    0.048    0.000    0.409    0.000 _utils.py:178(_rebuild_tensor)
        1    0.000    0.000    0.000    0.000 _weakrefset.py:111(remove)
        2    0.000    0.000    0.000    0.000 _weakrefset.py:17(__init__)
        2    0.000    0.000    0.000    0.000 _weakrefset.py:21(__enter__)
        2    0.000    0.000    0.000    0.000 _weakrefset.py:27(__exit__)
        8    0.000    0.000    0.000    0.000 _weakrefset.py:39(_remove)
        2    0.000    0.000    0.000    0.000 _weakrefset.py:53(_commit_removals)
        3    0.000    0.000    0.000    0.000 _weakrefset.py:63(__iter__)
       17    0.000    0.000    0.000    0.000 _weakrefset.py:86(add)
        4    0.000    0.000    0.000    0.000 abc.py:117(__instancecheck__)
     7829    0.025    0.000    0.025    0.000 connection.py:117(__init__)
     7813    0.005    0.000    0.005    0.000 connection.py:130(__del__)
    87005    0.050    0.000    0.050    0.000 connection.py:134(_check_closed)
    39601    0.017    0.000    0.017    0.000 connection.py:138(_check_readable)
    31244    0.017    0.000    0.017    0.000 connection.py:142(_check_writable)
    16160    0.019    0.000    0.026    0.000 connection.py:168(fileno)
     7811    0.020    0.000    0.105    0.000 connection.py:173(close)
    23433    0.150    0.000    0.580    0.000 connection.py:181(send_bytes)
     7811    0.041    0.000    0.455    0.000 connection.py:202(send)
    31252    0.118    0.000   96.117    0.003 connection.py:208(recv_bytes)
     8349    0.038    0.000 9091.500    1.089 connection.py:253(poll)
     7811    0.003    0.000    0.003    0.000 connection.py:259(__enter__)
     7811    0.014    0.000    0.118    0.000 connection.py:262(__exit__)
     7813    0.011    0.000    0.085    0.000 connection.py:360(_close)
    31244    0.052    0.000    0.355    0.000 connection.py:365(_send)
    62504    0.251    0.000   95.719    0.002 connection.py:374(_recv)
    31244    0.096    0.000    0.484    0.000 connection.py:390(_send_bytes)
    31252    0.168    0.000   95.958    0.003 connection.py:413(_recv_bytes)
     8349    0.045    0.000 9091.452    1.089 connection.py:423(_poll)
     7811    0.074    0.000   97.772    0.013 connection.py:493(Client)
        9    0.000    0.000    0.001    0.000 connection.py:516(Pipe)
     7811    0.108    0.000    0.597    0.000 connection.py:623(SocketClient)
     7811    0.096    0.000   31.872    0.004 connection.py:732(deliver_challenge)
     7811    0.128    0.000   65.169    0.008 connection.py:747(answer_challenge)
     7811    0.017    0.000    0.017    0.000 connection.py:83(_validate_family)
     8357    0.136    0.000 9091.431    1.088 connection.py:917(wait)
     7827    0.011    0.000    0.011    0.000 connection.py:933(<listcomp>)
    15622    0.035    0.000    0.050    0.000 connection.py:95(address_type)
        9    0.001    0.000    0.033    0.004 context.py:100(Queue)
       33    0.000    0.000    0.000    0.000 context.py:187(get_context)
       32    0.000    0.000    0.000    0.000 context.py:197(get_start_method)
        8    0.000    0.000    0.076    0.009 context.py:222(_Popen)
       18    0.000    0.000    0.000    0.000 context.py:237(get_context)
        8    0.001    0.000    0.076    0.009 context.py:278(_Popen)
       19    0.000    0.000    0.023    0.001 context.py:65(Lock)
        1    0.000    0.000    0.000    0.000 context.py:75(Condition)
        4    0.000    0.000    0.000    0.000 context.py:80(Semaphore)
        9    0.000    0.000    0.001    0.000 context.py:85(BoundedSemaphore)
        1    0.000    0.000    0.000    0.000 context.py:90(Event)
        9    0.000    0.000    0.000    0.000 dataloader.py:1082(<genexpr>)
        1    0.000    0.000    0.066    0.066 dataloader.py:1087(_reset)
        1    0.000    0.000    0.000    0.000 dataloader.py:1103(<listcomp>)
     8349    0.047    0.000 9230.746    1.106 dataloader.py:1120(_try_get_data)
     7819    0.029    0.000 9230.775    1.181 dataloader.py:1266(_get_data)
     7812    0.119    0.000 9231.913    1.182 dataloader.py:1299(_next_data)
     7835    0.066    0.000    1.019    0.000 dataloader.py:1348(_try_put_index)
     7811    0.019    0.000    0.976    0.000 dataloader.py:1368(_process_data)
        8    0.000    0.000    0.000    0.000 dataloader.py:1375(_mark_worker_as_unavailable)
        2    0.000    0.000    0.027    0.014 dataloader.py:1401(_shutdown_workers)
        1    0.000    0.000    0.000    0.000 dataloader.py:1478(__del__)
        1    0.000    0.000    0.180    0.180 dataloader.py:382(_get_iterator)
        1    0.000    0.000    0.000    0.000 dataloader.py:389(multiprocessing_context)
        1    0.000    0.000    0.180    0.180 dataloader.py:426(__iter__)
        2    0.000    0.000    0.000    0.000 dataloader.py:441(_auto_collation)
        1    0.000    0.000    0.000    0.000 dataloader.py:445(_index_sampler)
        1    0.000    0.000    0.000    0.000 dataloader.py:457(__len__)
        1    0.000    0.000    0.000    0.000 dataloader.py:486(check_worker_number_rationality)
        1    0.000    0.000    0.001    0.001 dataloader.py:565(__init__)
        1    0.000    0.000    0.000    0.000 dataloader.py:610(_reset)
     7835    0.009    0.000    0.614    0.000 dataloader.py:620(_next_index)
     7812    0.107    0.000 9233.713    1.182 dataloader.py:626(__next__)
  1002881    0.153    0.000    0.153    0.000 dataloader.py:88(__iter__)
        1    0.000    0.000    0.000    0.000 dataloader.py:93(_get_distributed_settings)
        1    0.002    0.002    0.180    0.180 dataloader.py:991(__init__)
        1    0.000    0.000    0.000    0.000 distributed_c10d.py:460(default_pg)
        1    0.000    0.000    0.000    0.000 distributed_c10d.py:588(WORLD)
        1    0.000    0.000    0.000    0.000 distributed_c10d.py:973(is_initialized)
        1    0.000    0.000    0.000    0.000 functools.py:393(__get__)
    15622    0.012    0.000    0.012    0.000 hmac.py:139(_current)
    15622    0.028    0.000    0.089    0.000 hmac.py:151(digest)
    15622    0.035    0.000    0.260    0.000 hmac.py:167(new)
    15622    0.050    0.000    0.225    0.000 hmac.py:38(__init__)
    15622    0.058    0.000    0.167    0.000 hmac.py:66(_init_hmac)
     6160    0.010    0.000    0.010    0.000 iostream.py:138(_event_pipe)
     6160    0.030    0.000    0.233    0.000 iostream.py:259(schedule)
        1    0.000    0.000    0.000    0.000 iostream.py:364(fileno)
     2043    0.005    0.000    0.006    0.000 iostream.py:505(parent_header)
     2043    0.004    0.000    0.008    0.000 iostream.py:550(_is_master_process)
     2043    0.008    0.000    0.136    0.000 iostream.py:577(_schedule_flush)
     2059    0.029    0.000    3.042    0.001 iostream.py:592(flush)
     2043    0.043    0.000    0.198    0.000 iostream.py:655(write)
        8    0.000    0.000    0.002    0.000 ipkernel.py:768(init_closure)
        2    0.000    0.000    0.000    0.000 ipkernel.py:775(_clean_thread_parent_frames)
        1    0.000    0.000    0.000    0.000 ipkernel.py:790(<setcomp>)
        1    0.000    0.000    0.000    0.000 os.py:675(__getitem__)
        1    0.000    0.000    0.000    0.000 os.py:755(encode)
        8    0.000    0.000    0.075    0.009 popen_fork.py:15(__init__)
     4281    0.006    0.000    0.024    0.000 popen_fork.py:24(poll)
        8    0.000    0.000    0.027    0.003 popen_fork.py:36(wait)
        8    0.001    0.000    0.049    0.006 popen_fork.py:62(_launch)
        8    0.001    0.000    0.077    0.010 process.py:110(start)
        8    0.000    0.000    0.027    0.003 process.py:142(join)
     4245    0.013    0.000    0.041    0.000 process.py:153(is_alive)
        8    0.000    0.000    0.000    0.000 process.py:189(name)
        8    0.000    0.000    0.000    0.000 process.py:205(daemon)
     7811    0.010    0.000    0.010    0.000 process.py:213(authkey)
        8    0.000    0.000    0.000    0.000 process.py:234(ident)
     7843    0.006    0.000    0.006    0.000 process.py:37(current_process)
        8    0.000    0.000    0.000    0.000 process.py:61(_cleanup)
        8    0.001    0.000    0.001    0.000 process.py:80(__init__)
       16    0.000    0.000    0.000    0.000 process.py:94(<genexpr>)
     4269    0.002    0.000    0.002    0.000 process.py:99(_check_closed)
     7812    0.037    0.000    0.054    0.000 profiler.py:593(__init__)
     7812    0.038    0.000    0.517    0.000 profiler.py:604(__enter__)
     7812    0.086    0.000    1.122    0.000 profiler.py:610(__exit__)
        8    0.000    0.000    0.000    0.000 queues.py:140(close)
       16    0.000    0.000    0.000    0.000 queues.py:153(cancel_join_thread)
        8    0.001    0.000    0.058    0.007 queues.py:161(_start_thread)
        8    0.000    0.000    0.000    0.000 queues.py:204(_finalize_close)
        9    0.001    0.000    0.031    0.003 queues.py:37(__init__)
        9    0.001    0.000    0.005    0.001 queues.py:71(_reset)
     7842    0.060    0.000    0.335    0.000 queues.py:86(put)
     8349    0.135    0.000 9230.658    1.106 queues.py:98(get)
       32    0.001    0.000    0.010    0.000 random.py:506(choices)
       32    0.009    0.000    0.009    0.000 random.py:519(<listcomp>)
     7811    0.129    0.000   38.008    0.005 reduction.py:153(recvfds)
     7811    0.079    0.000   38.264    0.005 reduction.py:186(recv_handle)
     7811    0.112    0.000    0.233    0.000 reduction.py:38(__init__)
     7811    0.052    0.000    0.317    0.000 reduction.py:48(dumps)
     7811    0.076    0.000    0.484    0.000 reductions.py:109(rebuild_tensor)
     7811    0.038    0.000    0.043    0.000 reductions.py:32(__init__)
     8256    0.007    0.000    0.014    0.000 reductions.py:45(expired)
    15622    0.028    0.000    0.116    0.000 reductions.py:479(fd_id)
     7764    0.005    0.000    0.012    0.000 reductions.py:48(__del__)
     7811    0.017    0.000    0.063    0.000 reductions.py:487(storage_from_cache)
     7811    0.135    0.000  137.586    0.018 reductions.py:494(rebuild_storage_fd)
     7811    0.069    0.000    0.147    0.000 reductions.py:531(rebuild_typed_storage)
     7811    0.037    0.000    0.046    0.000 reductions.py:76(get)
     7811    0.030    0.000    0.075    0.000 reductions.py:80(__setitem__)
       64    0.013    0.000    0.039    0.001 reductions.py:86(free_dead_references)
     7811    0.072    0.000  136.913    0.018 resource_sharer.py:55(detach)
     7811    0.167    0.000   98.456    0.013 resource_sharer.py:81(get_connection)
     7837    0.443    0.000    0.596    0.000 sampler.py:274(__iter__)
     8357    0.003    0.000    0.003    0.000 selectors.py:200(__enter__)
     8357    0.015    0.000    0.053    0.000 selectors.py:203(__exit__)
     8357    0.026    0.000    0.043    0.000 selectors.py:21(_fileobj_to_fd)
     8357    0.028    0.000    0.037    0.000 selectors.py:210(__init__)
     8357    0.013    0.000    0.055    0.000 selectors.py:216(_fileobj_lookup)
     8357    0.057    0.000    0.131    0.000 selectors.py:235(register)
     8357    0.032    0.000    0.038    0.000 selectors.py:269(close)
     7827    0.015    0.000    0.015    0.000 selectors.py:276(_key_from_fd)
     8357    0.035    0.000    0.080    0.000 selectors.py:348(__init__)
     8357    0.041    0.000    0.179    0.000 selectors.py:352(register)
     8357    0.150    0.000 9090.966    1.088 selectors.py:403(select)
     8357    0.009    0.000    0.009    0.000 selectors.py:64(__init__)
        1    0.000    0.000    0.000    0.000 signal_handling.py:47(_set_SIGCHLD_handler)
        8    0.000    0.000    0.000    0.000 signal_handling.py:63(handler)
    15622    0.193    0.000    0.193    0.000 socket.py:220(__init__)
    15622    0.006    0.000    0.006    0.000 socket.py:236(__enter__)
    15622    0.019    0.000    0.068    0.000 socket.py:239(__exit__)
     7811    0.014    0.000    0.031    0.000 socket.py:494(_real_close)
     7811    0.018    0.000    0.048    0.000 socket.py:498(close)
     7811    0.027    0.000    0.032    0.000 socket.py:504(detach)
     7811    0.023    0.000    0.098    0.000 socket.py:539(fromfd)
     6160    0.159    0.000    0.159    0.000 socket.py:545(send)
     2044    0.012    0.000    0.016    0.000 std.py:102(acquire)
     2044    0.011    0.000    0.013    0.000 std.py:106(release)
        4    0.000    0.000    0.000    0.000 std.py:110(__enter__)
        4    0.000    0.000    0.000    0.000 std.py:113(__exit__)
        1    0.000    0.000    0.000    0.000 std.py:1147(__del__)
     2041    0.017    0.000    0.150    0.000 std.py:1150(__str__)
        2    0.000    0.000    0.000    0.000 std.py:1153(_comparable)
        2    0.000    0.000    0.000    0.000 std.py:1157(__hash__)
     7812    0.050    0.000 9237.568    1.182 std.py:1160(__iter__)
     2039    0.035    0.000    3.617    0.002 std.py:1198(update)
        2    0.000    0.000    0.001    0.000 std.py:1265(close)
        2    0.000    0.000    0.000    0.000 std.py:1286(fp_write)
        1    0.000    0.000    0.000    0.000 std.py:1301(<lambda>)
     2040    0.017    0.000    3.553    0.002 std.py:1325(refresh)
     2041    0.021    0.000    0.028    0.000 std.py:1446(format_dict)
     2041    0.014    0.000    3.508    0.002 std.py:1464(display)
        3    0.000    0.000    0.000    0.000 std.py:226(__init__)
    12204    0.034    0.000    0.034    0.000 std.py:231(__call__)
     4081    0.028    0.000    0.031    0.000 std.py:400(format_interval)
        1    0.000    0.000    0.002    0.002 std.py:438(status_printer)
     2041    0.013    0.000    3.241    0.002 std.py:451(fp_write)
     2041    0.013    0.000    3.341    0.002 std.py:457(print_status)
     2041    0.061    0.000    0.106    0.000 std.py:464(format_meter)
        1    0.000    0.000    0.000    0.000 std.py:663(__new__)
        1    0.000    0.000    0.000    0.000 std.py:679(_get_free_pos)
        1    0.000    0.000    0.000    0.000 std.py:682(<setcomp>)
        1    0.000    0.000    0.000    0.000 std.py:686(_decr_instances)
        1    0.000    0.000    0.000    0.000 std.py:760(get_lock)
        1    0.000    0.000    0.002    0.002 std.py:952(__init__)
     8256    0.004    0.000    0.007    0.000 storage.py:1131(_expired)
     7811    0.030    0.000    0.034    0.000 storage.py:509(__new__)
     7811    0.039    0.000    0.044    0.000 storage.py:581(__init__)
     7764    0.004    0.000    0.007    0.000 storage.py:974(_free_weak_ref)
       32    0.000    0.000    0.011    0.000 synchronize.py:114(_make_name)
        4    0.000    0.000    0.000    0.000 synchronize.py:125(__init__)
        9    0.000    0.000    0.001    0.000 synchronize.py:144(__init__)
       19    0.000    0.000    0.022    0.001 synchronize.py:161(__init__)
        1    0.000    0.000    0.000    0.000 synchronize.py:212(__init__)
        9    0.000    0.000    0.000    0.000 synchronize.py:229(__enter__)
        9    0.000    0.000    0.000    0.000 synchronize.py:232(__exit__)
        1    0.000    0.000    0.000    0.000 synchronize.py:235(_make_methods)
        1    0.000    0.000    0.000    0.000 synchronize.py:270(notify)
        1    0.000    0.000    0.000    0.000 synchronize.py:296(notify_all)
        1    0.000    0.000    0.000    0.000 synchronize.py:323(__init__)
        8    0.000    0.000    0.000    0.000 synchronize.py:327(is_set)
        1    0.000    0.000    0.000    0.000 synchronize.py:334(set)
       32    0.010    0.000    0.023    0.001 synchronize.py:50(__init__)
       32    0.000    0.000    0.000    0.000 synchronize.py:90(_make_methods)
        9    0.000    0.000    0.000    0.000 synchronize.py:94(__enter__)
        9    0.000    0.000    0.000    0.000 synchronize.py:97(__exit__)
       32    0.000    0.000    0.000    0.000 tempfile.py:142(rng)
       32    0.001    0.000    0.011    0.000 tempfile.py:153(__next__)
     8219    0.010    0.000    0.018    0.000 threading.py:1102(_wait_for_tstate_lock)
     4143    0.005    0.000    0.005    0.000 threading.py:1145(ident)
     8219    0.021    0.000    0.041    0.000 threading.py:1169(is_alive)
        8    0.000    0.000    0.000    0.000 threading.py:1183(daemon)
        8    0.000    0.000    0.000    0.000 threading.py:1198(daemon)
        8    0.000    0.000    0.000    0.000 threading.py:1301(_make_invoke_excepthook)
     2068    0.005    0.000    0.006    0.000 threading.py:1430(current_thread)
        1    0.000    0.000    0.000    0.000 threading.py:1478(enumerate)
     2076    0.025    0.000    0.025    0.000 threading.py:236(__init__)
     9917    0.014    0.000    0.021    0.000 threading.py:264(__enter__)
     9917    0.024    0.000    0.027    0.000 threading.py:267(__exit__)
     2053    0.002    0.000    0.003    0.000 threading.py:273(_release_save)
     2053    0.003    0.000    0.006    0.000 threading.py:276(_acquire_restore)
     9903    0.011    0.000    0.018    0.000 threading.py:279(_is_owned)
     2053    0.019    0.000    2.874    0.001 threading.py:288(wait)
     7850    0.039    0.000    0.177    0.000 threading.py:359(notify)
     2067    0.007    0.000    0.034    0.000 threading.py:545(__init__)
     8236    0.003    0.000    0.003    0.000 threading.py:553(is_set)
     2067    0.014    0.000    2.906    0.001 threading.py:589(wait)
        8    0.001    0.000    0.002    0.000 threading.py:827(__init__)
        8    0.000    0.000    0.056    0.007 threading.py:916(start)
     7812    0.014    0.000    0.014    0.000 typing.py:306(inner)
        1    0.000    0.000    0.000    0.000 tz.py:218(utcoffset)
        1    0.000    0.000    0.000    0.000 tz.py:262(_isdst)
       41    0.000    0.000    0.001    0.000 util.py:171(register_after_fork)
       16    0.001    0.000    0.001    0.000 util.py:186(__init__)
       16    0.000    0.000    0.002    0.000 util.py:205(__call__)
        8    0.000    0.000    0.025    0.003 util.py:433(_flush_std_streams)
       16    0.000    0.000    0.000    0.000 util.py:44(sub_debug)
        8    0.000    0.000    0.002    0.000 util.py:461(close_fds)
       80    0.000    0.000    0.000    0.000 util.py:48(debug)
        1    0.000    0.000    0.000    0.000 utils.py:125(__eq__)
        2    0.000    0.000    0.000    0.000 utils.py:139(__getattr__)
        3    0.000    0.000    0.000    0.000 utils.py:152(wrapper_setattr)
        1    0.000    0.000    0.000    0.000 utils.py:156(__init__)
        2    0.000    0.000    0.000    0.000 utils.py:187(disable_on_exception)
     4084    0.015    0.000    3.228    0.001 utils.py:194(inner)
        1    0.000    0.000    0.000    0.000 utils.py:213(__init__)
        2    0.000    0.000    0.000    0.000 utils.py:222(__eq__)
        1    0.000    0.000    0.000    0.000 utils.py:252(_is_utf)
        1    0.000    0.000    0.000    0.000 utils.py:266(_supports_unicode)
        1    0.000    0.000    0.000    0.000 utils.py:282(_screen_shape_wrapper)
        1    0.000    0.000    0.000    0.000 utils.py:333(_screen_shape_linux)
        1    0.000    0.000    0.000    0.000 utils.py:347(<listcomp>)
     2041    0.008    0.000    0.071    0.000 utils.py:374(_text_width)
    53902    0.032    0.000    0.049    0.000 utils.py:375(<genexpr>)
     2041    0.007    0.000    0.084    0.000 utils.py:378(disp_len)
       25    0.000    0.000    0.000    0.000 weakref.py:106(remove)
       41    0.000    0.000    0.001    0.000 weakref.py:165(__setitem__)
       41    0.000    0.000    0.000    0.000 weakref.py:348(__new__)
       41    0.000    0.000    0.000    0.000 weakref.py:353(__init__)
    16210    0.011    0.000    0.011    0.000 {built-in method __new__ of type object at 0x55868a8fe5e0}
        4    0.000    0.000    0.000    0.000 {built-in method _abc._abc_instancecheck}
     8256    0.003    0.000    0.003    0.000 {built-in method _expired}
     7764    0.003    0.000    0.003    0.000 {built-in method _free_weak_ref}
    15622    0.108    0.000    0.108    0.000 {built-in method _hashlib.hmac_new}
     7811    0.224    0.000    0.224    0.000 {built-in method _new_shared_fd_cpu}
     7819    0.490    0.000  138.707    0.018 {built-in method _pickle.loads}
     7811    0.008    0.000    0.008    0.000 {built-in method _socket.CMSG_SPACE}
     7811    0.029    0.000    0.029    0.000 {built-in method _socket.dup}
    31244    0.028    0.000    0.028    0.000 {built-in method _struct.pack}
    31252    0.058    0.000    0.058    0.000 {built-in method _struct.unpack}
     4129    0.007    0.000    0.007    0.000 {built-in method _thread.allocate_lock}
     2076    0.001    0.000    0.001    0.000 {built-in method _thread.get_ident}
        8    0.002    0.000    0.002    0.000 {built-in method _thread.start_new_thread}
       25    0.000    0.000    0.000    0.000 {built-in method _weakref._remove_dead_weakref}
        2    0.000    0.000    0.000    0.000 {built-in method _weakref.proxy}
     2043    0.001    0.000    0.001    0.000 {built-in method builtins.abs}
     8162    0.003    0.000    0.003    0.000 {built-in method builtins.divmod}
        1    0.000    0.000 9238.469 9238.469 {built-in method builtins.exec}
     7855    0.006    0.000    0.006    0.000 {built-in method builtins.getattr}
     4089    0.002    0.000    0.002    0.000 {built-in method builtins.hasattr}
       45    0.000    0.000    0.000    0.000 {built-in method builtins.id}
    96874    0.034    0.000    0.034    0.000 {built-in method builtins.isinstance}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
232852/232851    0.071    0.000    0.071    0.000 {built-in method builtins.len}
     2105    0.003    0.000    0.003    0.000 {built-in method builtins.max}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.min}
    15818    0.014    0.000    0.620    0.000 {built-in method builtins.next}
     2041    0.015    0.000    0.064    0.000 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {built-in method fcntl.ioctl}
     2041    0.014    0.000    0.014    0.000 {built-in method fromtimestamp}
     8357    0.010    0.000    0.010    0.000 {built-in method math.ceil}
      256    0.000    0.000    0.000    0.000 {built-in method math.floor}
    15656    0.088    0.000    0.088    0.000 {built-in method posix.close}
        8    0.046    0.006    0.047    0.006 {built-in method posix.fork}
    15622    0.088    0.000    0.088    0.000 {built-in method posix.fstat}
    14196    0.026    0.000    0.026    0.000 {built-in method posix.getpid}
       25    0.000    0.000    0.000    0.000 {built-in method posix.pipe}
    62504   95.406    0.002   95.406    0.002 {built-in method posix.read}
        1    0.000    0.000    0.000    0.000 {built-in method posix.sched_getaffinity}
     7811    0.047    0.000    0.047    0.000 {built-in method posix.urandom}
     4273    0.018    0.000    0.018    0.000 {built-in method posix.waitpid}
        8    0.000    0.000    0.000    0.000 {built-in method posix.waitstatus_to_exitcode}
    31244    0.294    0.000    0.294    0.000 {built-in method posix.write}
     8357    0.008    0.000    0.008    0.000 {built-in method select.poll}
    25585    0.012    0.000    0.012    0.000 {built-in method time.monotonic}
    13913    0.007    0.000    0.007    0.000 {built-in method time.time}
        8    0.000    0.000    0.000    0.000 {built-in method torch._C._error_if_any_worker_fails}
        1    0.000    0.000    0.000    0.000 {built-in method torch._C._remove_worker_pids}
        1    0.000    0.000    0.000    0.000 {built-in method torch._C._set_worker_pids}
     7812    0.461    0.000    0.461    0.000 {built-in method torch._ops.profiler._record_function_enter_new}
     7812    1.012    0.000    1.012    0.000 {built-in method torch._ops.profiler.}
     7812    0.225    0.000    0.225    0.000 {built-in method torch.empty}
    51861    0.017    0.000    0.017    0.000 {built-in method unicodedata.east_asian_width}
     7811    0.017    0.000    0.017    0.000 {function socket.close at 0x7fdf404f64d0}
     7811    0.004    0.000    0.004    0.000 {function socket.detach at 0x7fdf404f6560}
        9    0.000    0.000    0.000    0.000 {method '__enter__' of '_multiprocessing.SemLock' objects}
     9917    0.007    0.000    0.007    0.000 {method '__enter__' of '_thread.lock' objects}
        9    0.000    0.000    0.000    0.000 {method '__exit__' of '_multiprocessing.SemLock' objects}
     2052    0.001    0.000    0.001    0.000 {method '__exit__' of '_thread.RLock' objects}
    25539    0.009    0.000    0.009    0.000 {method '__exit__' of '_thread.lock' objects}
     7812    0.005    0.000    0.005    0.000 {method '__exit__' of 'torch._C.DisableTorchFunctionSubclass' objects}
        1    0.000    0.000    0.000    0.000 {method '_is_mine' of '_multiprocessing.SemLock' objects}
     7811    0.005    0.000    0.005    0.000 {method '_weak_ref' of 'torch._C.StorageBase' objects}
    18247    0.021    0.000    0.021    0.000 {method 'acquire' of '_multiprocessing.SemLock' objects}
     2052    0.002    0.000    0.002    0.000 {method 'acquire' of '_thread.RLock' objects}
    24281    2.858    0.000    2.858    0.000 {method 'acquire' of '_thread.lock' objects}
       27    0.000    0.000    0.000    0.000 {method 'add' of 'set' objects}
    16063    0.005    0.000    0.005    0.000 {method 'append' of 'collections.deque' objects}
     7843    0.003    0.000    0.003    0.000 {method 'append' of 'list' objects}
        8    0.000    0.000    0.000    0.000 {method 'clear' of 'collections.deque' objects}
     8357    0.006    0.000    0.006    0.000 {method 'clear' of 'dict' objects}
     7811    0.235    0.000    0.235    0.000 {method 'connect' of '_socket.socket' objects}
     7819    0.013    0.000    0.013    0.000 {method 'copy' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'difference' of 'set' objects}
    15622    0.049    0.000    0.049    0.000 {method 'digest' of '_hashlib.HMAC' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       24    0.000    0.000    0.000    0.000 {method 'discard' of 'set' objects}
     7811    0.026    0.000    0.026    0.000 {method 'dump' of '_pickle.Pickler' objects}
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
     7811    0.008    0.000    0.008    0.000 {method 'frombytes' of 'array.array' objects}
     2043    0.001    0.000    0.001    0.000 {method 'get' of '_contextvars.ContextVar' objects}
     7819    0.006    0.000    0.006    0.000 {method 'get' of 'dict' objects}
     7811    0.006    0.000    0.006    0.000 {method 'getbuffer' of '_io.BytesIO' objects}
    62504    0.022    0.000    0.022    0.000 {method 'getvalue' of '_io.BytesIO' objects}
        1    0.000    0.000    0.000    0.000 {method 'item' of 'torch._C.TensorBase' objects}
     2107    0.001    0.000    0.001    0.000 {method 'items' of 'dict' objects}
       40    0.000    0.000    0.000    0.000 {method 'join' of 'str' objects}
        4    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
     8357 9090.788    1.088 9090.788    1.088 {method 'poll' of 'select.poll' objects}
     5739    0.005    0.000    0.005    0.000 {method 'pop' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 {method 'pop' of 'list' objects}
      256    0.000    0.000    0.000    0.000 {method 'random' of '_random.Random' objects}
        1    0.000    0.000    0.000    0.000 {method 'random_' of 'torch._C.TensorBase' objects}
     7811   37.854    0.005   37.854    0.005 {method 'recvmsg' of '_socket.socket' objects}
     8357    0.007    0.000    0.007    0.000 {method 'register' of 'select.poll' objects}
    18213    0.014    0.000    0.014    0.000 {method 'release' of '_multiprocessing.SemLock' objects}
     2052    0.001    0.000    0.001    0.000 {method 'release' of '_thread.RLock' objects}
     9893    0.120    0.000    0.120    0.000 {method 'release' of '_thread.lock' objects}
     7840    0.003    0.000    0.003    0.000 {method 'remove' of 'collections.deque' objects}
        3    0.000    0.000    0.000    0.000 {method 'remove' of 'set' objects}
     7871    0.008    0.000    0.008    0.000 {method 'rpartition' of 'str' objects}
     7811    0.136    0.000    0.136    0.000 {method 'set_' of 'torch._C.TensorBase' objects}
     7811    0.025    0.000    0.025    0.000 {method 'setblocking' of '_socket.socket' objects}
    15622    0.015    0.000    0.015    0.000 {method 'startswith' of 'str' objects}
     2041    0.005    0.000    0.005    0.000 {method 'sub' of 're.Pattern' objects}
     7811    0.108    0.000    0.108    0.000 {method 'update' of 'dict' objects}
        2    0.000    0.000    0.000    0.000 {method 'values' of 'dict' objects}
    62504    0.037    0.000    0.037    0.000 {method 'write' of '_io.BytesIO' objects}
     2043    0.001    0.000    0.001    0.000 {method 'write' of '_io.StringIO' objects}
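
As a rough reading of the listing above, the two figures that matter can be worked out directly from it. The numbers below are copied from the profile; the batch size of 128 is an assumption carried over from the first post, since the `run_local` code itself is not shown.

```python
# Values copied from the profile listing above.
total_seconds = 9238.469   # total run time reported by cProfile
poll_seconds = 9090.788    # tottime of {method 'poll' of 'select.poll' objects}
batches = 7811             # ncalls of dataloader.py:1368(_process_data)

# Assumption: batch_size=128, as in the DataLoader from the first post.
batch_size = 128

poll_share = poll_seconds / total_seconds
rows_per_second = batches * batch_size / total_seconds

print(f"share of time spent in poll: {poll_share:.1%}")
print(f"approximate throughput: {rows_per_second:.0f} rows/s")
```

Since `poll` appears to be reached through `queues.py:98(get)` and `dataloader.py:1120(_try_get_data)`, this time is the main process waiting for the worker processes to hand over the next batch. cProfile only instruments the process it runs in, so whatever the workers are doing during that wait (reading files, decoding, shuffling) is not broken down here. Profiling a run with `num_workers=0` would be one way to see it.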

I switched to Ray, which works much faster on the first try without any optimization, just using the example code from the Ray docs. If speed is the priority, I would say Ray is the way to go.

Hi @zhh210, can you please provide the exact Ray code you substituted? And how big is the difference between the two approaches?

Thanks

I’d also like to see the Ray implementation. I would love to see if I could get a speedup in my data loading as well.

Ray’s implementation is quite straightforward and well documented here. In my use case, I just use the default one-liner with no tweaks at all, and it runs much faster than the heavily tuned HF dataset:

import ray
import s3fs
import torch

s3 = s3fs.S3FileSystem()  # same filesystem handle as in the first post
ds = ray.data.read_parquet(['s3://' + i for i in s3.glob('s3://my_s3_bucket/*')])
dl_train = ds.iter_torch_batches(batch_size=1600, dtypes=torch.float32)