If I have multiple serialized datasets, is it good practice to use multiprocessing along with datasets.load_from_disk()?
Something like this:
import multiprocessing

import datasets
import gcsfs

BUCKET_NAME = "my-bucket"
GCS_FS = gcsfs.GCSFileSystem()


def load_ds(ds_path):
    # Each worker should load one serialized dataset directly from GCS.
    return datasets.load_from_disk(ds_path, fs=GCS_FS)


# Collect the paths of the tokenized datasets saved in the bucket.
ds_dirs = GCS_FS.listdir(f"{BUCKET_NAME}/saved_datasets")
ds_dirs = list(
    {
        f"{dd['name']}"
        for dd in ds_dirs
        if "tokenized" in dd["name"]
    }
)

# Load the datasets in parallel, then merge them into a single training set.
with multiprocessing.Pool() as pool:
    ds_list = pool.starmap_async(load_ds, ds_dirs).get()
    ds_list = [ds for ds in ds_list]

train_ds = datasets.concatenate_datasets(ds_list)
print(train_ds)
This snippet does not run as expected, though.
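For reference, the result I am after is what a plain sequential version would produce, roughly along these lines (same bucket layout as above; the bucket name is just a placeholder):

import datasets
import gcsfs

BUCKET_NAME = "my-bucket"
GCS_FS = gcsfs.GCSFileSystem()

# Same selection of tokenized dataset directories as above.
ds_dirs = [
    dd["name"]
    for dd in GCS_FS.listdir(f"{BUCKET_NAME}/saved_datasets")
    if "tokenized" in dd["name"]
]

# Load each dataset one after another, then concatenate into one training set.
ds_list = [datasets.load_from_disk(ds_dir, fs=GCS_FS) for ds_dir in ds_dirs]
train_ds = datasets.concatenate_datasets(ds_list)
print(train_ds)

The question is essentially whether replacing this sequential loop with a multiprocessing pool is a sensible way to speed things up, or whether there is a recommended approach.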