Custom 20GB Arrow dataset very slow to train

Hi everyone,
This is my first post in the forum, so please let me know if there is any issue with the post or if it should be in another category :slight_smile:

I have a dataset of 300M events, where each event belongs to one of 6K users. Each event has a timestamp and a few fields of numeric metadata.
Each user has a binary label, and I want to train a classifier on this data.

My approach is to split each user’s events into contiguous chunks of fixed size, tokenize them with a custom function,
and train a sequence classification model on each chunk.
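
Roughly, this is what I mean by chunking (a toy sketch only; chunk_user_events is just an illustrative helper, and the real logic lives in the ArrowBrowsingDataset class further down):

def chunk_user_events(events, chunk_size=512):
    # Yield contiguous, fixed-size windows over one user's time-ordered events.
    for start in range(0, len(events), chunk_size):
        yield events[start:start + chunk_size]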

Sample of a user’s data (in JSONL format):

{"Datetime": "2023-05-04 13:19:55+03:00", "URL": 1308939, "Domain_Name": 2362255, "Domain_cls1": 624, "Domain_cls2": 0, "Domain_cls3": 0, "Domain_cls4": 0, "Target": 0}
{"Datetime": "2023-05-04 13:19:56+03:00", "URL": 819346, "Domain_Name": 69088, "Domain_cls1": 360, "Domain_cls2": 0, "Domain_cls3": 0, "Domain_cls4": 0, "Target": 0}
{"Datetime": "2023-05-04 13:20:01+03:00", "URL": 1302705, "Domain_Name": 1129695, "Domain_cls1": 552, "Domain_cls2": 0, "Domain_cls3": 0, "Domain_cls4": 0, "Target": 0}

For efficient training, I converted the JSONL files to an Arrow dataset using load_dataset and concatenate_datasets:

import pathlib
from datasets import load_dataset, concatenate_datasets

# Load the JSONL files into a Hugging Face dataset.
# Each file becomes a dataset of individual visits.
# `file_paths` is a list of per-user JSONL files (one file per user;
# the filename stem is the user id).
datasets_list = []
for fp in file_paths:
    # Each line is a JSON record.
    ds = load_dataset("json", data_files=fp, split="train")
    # Add a field with the user id (derived from filename)
    user_id = int(pathlib.Path(fp).stem)
    ds = ds.add_column("user_id", [user_id]*len(ds))
    datasets_list.append(ds)

# Concatenate all per-file datasets into one.
full_ds = concatenate_datasets(datasets_list)

# Save the full dataset to disk in Arrow format.
full_ds.save_to_disk("C:\\Users\\ilayl\\Downloads\\arrow_dataset")

The total size of the Arrow dataset is roughly 20GB.

I trained a model on this data using the code below, but the final trainer.train() step is incredibly slow, roughly 10 seconds per step (see the profiler results below).

It seems like the bottleneck is a slow .select(indices) call on the Arrow dataset, even though (if I understand correctly) the dataset is memory-mapped and queries like this should be efficient.
Do you know of a way to speed up these dataset queries, or of existing Hugging Face functionality that solves this problem more efficiently?

I am running on a Windows 11 machine with 64GB of memory, with transformers version 4.49.0 and datasets version 3.4.0.

import math
import random

import numpy as np
import torch
from datasets import load_from_disk
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import BertConfig, BertForSequenceClassification, Trainer, TrainingArguments

def tokenize_event(event):
    # placeholder tokenization
    return [event["Domain_cls1"]]


PAD_TOKEN_ID = 1000
VOCAB_SIZE = PAD_TOKEN_ID + 1
TOKENS_PER_LINE = 1


class ArrowBrowsingDataset(Dataset):
    def __init__(self, arrow_dataset, chunk_size=512, batch_size=10000):
        """
        Build metadata for each user by scanning through the dataset in slices.
        Only the "user_id" and "Target" columns are loaded per slice.
        """
        self.arrow_ds = arrow_dataset
        self.chunk_size = chunk_size

        # Mapping: user_id -> (start_index, count, label)
        self.user_boundaries = {}
        # List of tuples: (user_id, chunk_idx) for later indexing
        self.index_mapping = []

        current_user = None
        current_start = 0
        current_count = 0
        current_label = None
        total_index = 0

        dataset_length = len(self.arrow_ds)
        # Iterate over dataset slices using indices
        for start in tqdm(range(0, dataset_length, batch_size)):
            end = min(start + batch_size, dataset_length)
            # Slicing the dataset returns a dictionary of lists
            batch = self.arrow_ds[start:end]
            user_ids = batch["user_id"]
            targets = batch["Target"]
            for uid, target in zip(user_ids, targets):
                if current_user is None:
                    # Initialize the first user block.
                    current_user = uid
                    current_start = total_index
                    current_count = 1
                    current_label = target
                elif uid == current_user:
                    # Continue the current user's block.
                    current_count += 1
                else:
                    # Save the completed user's boundaries.
                    self.user_boundaries[current_user] = (current_start, current_count, current_label)
                    num_chunks = math.ceil(current_count / self.chunk_size)
                    for chunk_idx in range(num_chunks):
                        self.index_mapping.append((current_user, chunk_idx))
                    # Start a new block for the new user.
                    current_user = uid
                    current_start = total_index
                    current_count = 1
                    current_label = target
                total_index += 1

        # Save the last user's boundaries.
        if current_user is not None:
            self.user_boundaries[current_user] = (current_start, current_count, current_label)
            num_chunks = math.ceil(current_count / self.chunk_size)
            for chunk_idx in range(num_chunks):
                self.index_mapping.append((current_user, chunk_idx))

        # Shuffle the index mapping if desired.
        random.shuffle(self.index_mapping)

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        """
        Returns a contiguous chunk of events for a single user.
        It loads only the necessary rows from the dataset.
        """
        user_id, chunk_idx = self.index_mapping[idx]
        start, count, label = self.user_boundaries[user_id]
        offset = start + chunk_idx * self.chunk_size
        length = min(self.chunk_size, (start + count) - offset)
        indices = list(range(offset, offset + length))

        # Use select to efficiently retrieve only the needed rows.
        events = self.arrow_ds.select(indices)

        # Process events by tokenizing them (assumes tokenize_event is defined).
        chunk_tokens = []
        for event in events:
            tokens = tokenize_event(event)
            chunk_tokens.extend(tokens)

        return {"input_ids": chunk_tokens, "labels": label}


def collate_fn(batch):
    """
    Pads each sequence in the batch to the maximum length in the batch,
    and creates an attention mask.
    """
    input_ids_list = [torch.tensor(sample["input_ids"], dtype=torch.long) for sample in batch]
    # Pad sequences using the PAD_TOKEN_ID.
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=PAD_TOKEN_ID)
    # Create attention mask: 1 for non-pad tokens, 0 for pad tokens.
    attention_mask = (input_ids != PAD_TOKEN_ID).long()
    labels = torch.tensor([sample["labels"] for sample in batch], dtype=torch.long)

    # Debug: Check maximum token id in the batch
    max_token = input_ids.max().item()
    if max_token >= VOCAB_SIZE:
        raise ValueError(f"Found token id {max_token} which is >= VOCAB_SIZE ({VOCAB_SIZE}). Adjust your VOCAB_SIZE or token offsets accordingly.")

    # Debug: Print unique labels
    if torch.min(labels) < 0 or torch.max(labels) >= 2:
        print("Unexpected label values in batch:", labels.tolist())

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Load the Arrow dataset.
arrow_ds = load_from_disk("C:\\Users\\ilayl\\Downloads\\arrow_dataset")

# perform a random train/test split based on user ids.
user_ids = list(set(arrow_ds["user_id"]))
random.seed(42)
random.shuffle(user_ids)
split = int(0.9 * len(user_ids))
train_user_ids = set(user_ids[:split])
test_user_ids = set(user_ids[split:])


# Filter the arrow dataset to create train and test subsets (if desired).
train_ds = arrow_ds.filter(lambda ex: ex["user_id"] in train_user_ids)
test_ds = arrow_ds.filter(lambda ex: ex["user_id"] in test_user_ids)


# Build PyTorch Datasets from the Arrow subsets.
train_dataset = ArrowBrowsingDataset(train_ds, chunk_size=512, batch_size=100000)
test_dataset = ArrowBrowsingDataset(test_ds, chunk_size=512, batch_size=100000)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    accuracy = (preds == labels).mean()
    return {"accuracy": accuracy}


# Define a small transformer configuration and model from scratch.
config = BertConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=256,              # smaller hidden size for demonstration
    num_hidden_layers=4,          # fewer layers
    num_attention_heads=4,
    intermediate_size=512,
    max_position_embeddings=512,
    pad_token_id=PAD_TOKEN_ID,
    num_labels=2,
)
model = BertForSequenceClassification(config)


# Define training arguments.
training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    max_steps=50,
    save_steps=1000,
    save_total_limit=2,
    logging_strategy="steps",
    logging_steps=100,
    # eval_strategy="steps",
    # eval_steps=100,
)

# Initialize the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)


# Start training.
trainer.train()

Below is the result of the Python profiler, after running max_steps=50:

         562835963 function calls (561799711 primitive calls) in 495.093 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    19992  303.494    0.015  363.394    0.018 table.py:106(__init__)
533283756/533283732   51.701    0.000   51.704    0.000 {built-in method builtins.len}
      408   36.598    0.090  412.889    1.012 arrow_dataset.py:581(update_metadata_with_features)
      408   33.987    0.083   33.988    0.083 arrow_dataset.py:1410(__del__)
   1409/3    9.209    0.007    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
   207759    7.606    0.000    7.992    0.000 formatting.py:144(extract_row)
    19992    6.878    0.000    7.880    0.000 fromnumeric.py:41(_wrapit)
        1    4.874    4.874  222.695  222.695 trainer.py:3815(save_model)
    193/0    4.232    0.022    0.000          threading.py:323(wait)
      355    2.837    0.008    4.486    0.013 {built-in method select.select}
  2077590    2.680    0.000    5.511    0.000 py_utils.py:324(zip_dict)
   207759    2.612    0.000   19.255    0.000 formatting.py:621(format_table)
   415926    2.550    0.000    3.946    0.000 table.py:130(fast_slice)
      355    2.301    0.006   27.975    0.079 selectors.py:313(_select)
       50    1.983    0.040  580.358   11.607 trainer.py:5175(get_batch_samples)
   207759    1.595    0.000    7.137    0.000 features.py:2078(decode_example)
  5609493    1.484    0.000    1.495    0.000 py_utils.py:328(<genexpr>)
   831444    1.320    0.000    1.397    0.000 table.py:79(_interpolation_search)
  2077590    1.119    0.000    1.336    0.000 py_utils.py:298(unique_values)
   207759    1.090    0.000    5.392    0.000 formatting.py:51(_query_table_with_indices_mapping)
    19992    0.961    0.000    0.961    0.000 {method 'cumsum' of 'numpy.ndarray' objects}
   207759    0.676    0.000   26.815    0.000 arrow_dataset.py:2747(_getitem)
216327/4896    0.564    0.000    1.579    0.000 pickle.py:532(save)
   415926    0.527    0.000    0.657    0.000 formatting.py:399(__init__)
    19176    0.509    0.000  222.299    0.012 table.py:1127(replace_schema_metadata)
260501/122408    0.450    0.000    0.966    0.000 copy.py:118(deepcopy)
   208167    0.431    0.000   27.249    0.000 arrow_dataset.py:2372(__iter__)
    39872    0.427    0.000    0.796    0.000 ipkernel.py:775(_clean_thread_parent_frames)
   207759    0.420    0.000    6.008    0.000 formatting.py:578(query_table)
      613    0.357    0.001    0.403    0.001 decorator.py:232(fun)
        1    0.354    0.354    0.355    0.355 {built-in method gc.collect}
   207759    0.351    0.000    0.351    0.000 formatting.py:665(<genexpr>)
   207759    0.348    0.000    0.386    0.000 formatting.py:127(_unnest)
    19992    0.347    0.000    8.236    0.000 fromnumeric.py:51(_wrapfunc)
   415926    0.336    0.000    0.994    0.000 formatting.py:452(__init__)
   207759    0.304    0.000   15.571    0.000 formatting.py:456(format_row)
      408    0.287    0.001  449.941    1.103 1420505965.py:70(__getitem__)
   208167    0.282    0.000    0.848    0.000 __init__.py:123(get_formatter)
   207759    0.266    0.000    2.393    0.000 formatting.py:81(_query_table)
       50    0.250    0.005    0.250    0.005 {method 'run_backward' of 'torch._C._EngineBase' objects}
2288713/2278925    0.237    0.000    0.249    0.000 {built-in method builtins.isinstance}
  1904228    0.225    0.000    0.229    0.000 {method 'add' of 'set' objects}
   207759    0.194    0.000    0.316    0.000 pickle.py:749(save_long)
     4157    0.189    0.000    0.189    0.000 {method 'to' of 'torch._C.TensorBase' objects}
   219499    0.175    0.000    0.175    0.000 threading.py:1198(ident)
216327/4896    0.167    0.000    1.585    0.000 _dill.py:367(save)
216327/4896    0.156    0.000    1.593    0.000 _dill.py:31(save)
    19936    0.152    0.000    0.172    0.000 threading.py:1535(enumerate)
       50    0.149    0.003    0.151    0.003 modeling_attn_mask_utils.py:429(_prepare_4d_attention_mask_for_sdpa)
   207759    0.138    0.000    7.275    0.000 formatting.py:224(decode_row)
   625317    0.132    0.000    0.132    0.000 table.py:390(num_rows)
63240/1224    0.122    0.000    0.331    0.000 py_utils.py:206(_asdict_inner)
   207759    0.120    0.000   15.690    0.000 formatting.py:409(__call__)
    19584    0.111    0.000  218.820    0.011 table.py:1011(__init__)
    95096    0.103    0.000    0.120    0.000 copy.py:231(_keep_alive)
   775456    0.101    0.000    0.101    0.000 {method 'get' of 'dict' objects}
     7403    0.100    0.000    0.100    0.000 {method 'item' of 'torch._C.TensorBase' objects}
   221223    0.100    0.000    0.134    0.000 pickle.py:211(commit_frame)
      816    0.096    0.000    1.514    0.002 pickle.py:956(_batch_appends)
   207759    0.090    0.000    0.105    0.000 formatting.py:549(_check_valid_index_key)
47355/39993    0.090    0.000    0.392    0.000 copy.py:217(_deepcopy_dict)
   511432    0.087    0.000    0.087    0.000 {method 'items' of 'dict' objects}
      408    0.082    0.000  365.308    0.895 table.py:1594(replace_schema_metadata)
   415926    0.076    0.000    0.076    0.000 formatting.py:218(__init__)
     1300    0.073    0.000    0.073    0.000 {built-in method torch._C._nn.linear}
   207759    0.069    0.000    0.083    0.000 formatting.py:568(key_to_query_type)
   760513    0.069    0.000    0.069    0.000 {built-in method builtins.id}
19994/19586    0.068    0.000    0.319    0.000 copy.py:200(_deepcopy_tuple)
133943/18224    0.066    0.000    0.071    0.000 module.py:2775(named_modules)
   228975    0.065    0.000    0.099    0.000 pickle.py:235(write)
    19992    0.064    0.000    8.345    0.000 fromnumeric.py:2904(cumsum)
   208575    0.062    0.000    0.062    0.000 __init__.py:115(get_format_type_from_alias)
     2856    0.061    0.000    0.061    0.000 table.py:420(column_names)
   350369    0.058    0.000    0.059    0.000 {built-in method builtins.getattr}
   207759    0.056    0.000    0.056    0.000 3971746712.py:1(tokenize_event)
     2649    0.055    0.000    0.154    0.000 inspect.py:2397(_signature_from_function)
   415926    0.055    0.000    0.055    0.000 formatting.py:235(__init__)
    19584    0.054    0.000    0.054    0.000 {built-in method nt._getfullpathname}
   208049    0.050    0.000    0.050    0.000 {method 'extend' of 'list' objects}
16320/1632    0.044    0.000    0.054    0.000 features.py:1235(get_nested_type)
   224514    0.043    0.000    0.043    0.000 {built-in method _struct.pack}
   216327    0.041    0.000    0.041    0.000 _dill.py:316(get)
      146    0.041    0.000    0.041    0.000 {method 'copy_' of 'torch._C.StorageBase' objects}
 9792/816    0.039    0.000    0.269    0.000 copy.py:247(_reconstruct)
    19992    0.037    0.000  363.432    0.018 table.py:166(__init__)
   238359    0.037    0.000    0.037    0.000 {method 'write' of '_io.BytesIO' objects}
    23216    0.036    0.000    0.036    0.000 {built-in method _abc._abc_instancecheck}
     1416    0.035    0.000    0.071    0.000 {built-in method builtins.all}
2850/2649    0.033    0.000    0.203    0.000 inspect.py:2501(_signature_from_callable)
    12651    0.033    0.000    0.055    0.000 inspect.py:2743(__init__)
     1224    0.032    0.000    0.240    0.000 fingerprint.py:307(format_kwargs_for_fingerprint)
     6760    0.031    0.000    0.045    0.000 <frozen importlib._bootstrap_external>:96(_path_join)
    63648    0.031    0.000    0.042    0.000 dataclasses.py:1301(is_dataclass)
      408    0.031    0.000  416.417    1.021 arrow_dataset.py:631(__init__)
    15248    0.030    0.000    0.048    0.000 dataclasses.py:1278(fields)
    19584    0.029    0.000    0.121    0.000 table.py:1044(_append_replay)
   208167    0.028    0.000    0.028    0.000 arrow_dataset.py:3926(<genexpr>)
  816/408    0.027    0.000  419.922    1.029 arrow_dataset.py:541(wrapper)
    19584    0.027    0.000    0.027    0.000 {built-in method nt._path_normpath}
      816    0.027    0.000    0.271    0.000 features.py:1809(arrow_schema)
    63648    0.026    0.000    0.070    0.000 py_utils.py:202(_is_dataclass_instance)
   226119    0.026    0.000    0.026    0.000 {method 'tell' of '_io.BytesIO' objects}
      408    0.025    0.000    0.248    0.001 features.py:1820(from_arrow_schema)
   216327    0.025    0.000    0.025    0.000 pickle.py:603(persistent_id)
    11398    0.024    0.000    0.092    0.000 module.py:2588(_named_members)
      500    0.022    0.000    0.022    0.000 {built-in method torch.dropout}
 4080/408    0.022    0.000    0.052    0.000 features.py:1444(generate_from_dict)
   122063    0.022    0.000    0.022    0.000 {method 'append' of 'list' objects}
    19584    0.022    0.000    0.103    0.000 <frozen ntpath>:581(abspath)
    22032    0.022    0.000    0.031    0.000 features.py:1659(require_decoding)
    19992    0.021    0.000    0.021    0.000 {method 'as_arrays' of 'numpy._core._multiarray_umath._array_converter' objects}
  624/559    0.020    0.000    0.023    0.000 socket.py:635(send)
      633    0.019    0.000    0.019    0.000 {built-in method torch.tensor}
     4896    0.018    0.000    1.651    0.000 pickle.py:473(dump)
     2856    0.018    0.000    0.050    0.000 features.py:1778(__init__)
    55870    0.017    0.000    0.017    0.000 {method 'values' of 'dict' objects}
      408    0.017    0.000  148.599    0.364 table.py:1300(__init__)
    23216    0.016    0.000    0.052    0.000 <frozen abc>:117(__instancecheck__)
19589/19587    0.016    0.000    0.029    0.000 copy.py:191(_deepcopy_list)
   158059    0.016    0.000    0.016    0.000 copy.py:172(_deepcopy_atomic)
     4896    0.016    0.000    0.033    0.000 _dill.py:351(__init__)
     2850    0.016    0.000    0.025    0.000 inspect.py:3029(__init__)
    80355    0.015    0.000    0.015    0.000 {method 'keys' of 'dict' objects}
     7344    0.015    0.000    0.026    0.000 pickle.py:869(save_str)
      450    0.014    0.000    0.014    0.000 {built-in method torch.layer_norm}
     4896    0.014    0.000    1.765    0.000 fingerprint.py:190(update)
     4896    0.014    0.000    0.015    0.000 pickle.py:406(__init__)
    67149    0.014    0.000    0.014    0.000 dataclasses.py:1293(<genexpr>)
   106100    0.014    0.000    0.014    0.000 {built-in method builtins.hasattr}
     1354    0.014    0.000    0.014    0.000 {built-in method nt.stat}
      100    0.014    0.000    0.024    0.000 adamw.py:132(_init_group)
        1    0.014    0.014    0.014    0.014 {built-in method torch._C._cuda_emptyCache}
       51    0.014    0.000  415.006    8.137 fetch.py:47(fetch)
  4150/50    0.014    0.000    0.406    0.008 module.py:1743(_call_impl)
     1432    0.013    0.000    0.014    0.000 encoder.py:205(iterencode)
 1224/408    0.013    0.000  419.874    1.029 fingerprint.py:410(wrapper)
     1352    0.013    0.000    0.075    0.000 <frozen importlib._bootstrap_external>:1597(find_spec)
     4588    0.013    0.000    0.031    0.000 module.py:1932(__setattr__)
      816    0.012    0.000    0.315    0.000 info.py:296(copy)
    14761    0.012    0.000    0.017    0.000 enum.py:720(__call__)
     4896    0.012    0.000    1.720    0.000 _dill.py:101(dump)
      408    0.011    0.000  412.554    1.011 arrow_dataset.py:3939(_select_contiguous)
     9792    0.011    0.000    0.013    0.000 {method '__reduce_ex__' of 'object' objects}
    19992    0.011    0.000    0.011    0.000 fromnumeric.py:2900(_cumsum_dispatcher)
  4437/51    0.011    0.000    0.053    0.001 module.py:2824(train)
      200    0.010    0.000    0.010    0.000 {built-in method torch._C._nn.scaled_dot_product_attention}
    19992    0.009    0.000    0.009    0.000 {method 'wrap' of 'numpy._core._multiarray_umath._array_converter' objects}
       50    0.009    0.000    0.009    0.000 {built-in method torch._foreach_norm}
     1346    0.009    0.000    0.009    0.000 decoder.py:344(raw_decode)
      150    0.009    0.000    0.009    0.000 {built-in method torch.embedding}
     2649    0.009    0.000    0.012    0.000 inspect.py:176(get_annotations)
      355    0.009    0.000   35.224    0.099 base_events.py:1922(_run_once)
       50    0.008    0.000    0.008    0.000 {built-in method torch.isinf}
     4896    0.008    0.000    0.017    0.000 logger.py:127(trace_setup)
     8492    0.008    0.000    0.008    0.000 module.py:1915(__getattr__)
    14688    0.008    0.000    0.008    0.000 {method 'update' of 'xxhash.xxh64' objects}
 4080/408    0.008    0.000    0.010    0.000 features.py:2193(recursive_reorder)
     4896    0.008    0.000    0.016    0.000 fingerprint.py:178(hash_bytes)
      100    0.008    0.000    0.060    0.001 adamw.py:486(_multi_tensor_adamw)
     4896    0.008    0.000    1.751    0.000 fingerprint.py:186(hash)
     7300    0.008    0.000    0.015    0.000 optimizer.py:101(_get_value)
      408    0.008    0.000    0.020    0.000 arrow_dataset.py:2477(set_format)
       50    0.008    0.000    0.064    0.001 modeling_utils.py:1141(num_parameters)
     4896    0.008    0.000    1.676    0.000 _dill.py:418(dump)
      816    0.007    0.000    1.777    0.002 fingerprint.py:227(update_fingerprint)
      100    0.007    0.000    0.007    0.000 {built-in method torch._foreach_sqrt}
    15300    0.007    0.000    0.009    0.000 inspect.py:3076(<genexpr>)
20400/1632    0.007    0.000    0.247    0.000 copy.py:252(<genexpr>)
  4150/50    0.007    0.000    0.406    0.008 module.py:1735(_wrapped_call_impl)
     2649    0.007    0.000    0.010    0.000 inspect.py:754(unwrap)
       50    0.007    0.000    0.007    0.000 {method 'ne' of 'torch._C.TensorBase' objects}
      200    0.006    0.000    0.006    0.000 {built-in method torch._foreach_add_}
      408    0.006    0.000  416.908    1.022 arrow_dataset.py:3850(select)
       51    0.006    0.000    0.039    0.001 module.py:2887(zero_grad)
    26457    0.006    0.000    0.009    0.000 _tensor.py:1166(__hash__)
     7960    0.006    0.000    0.013    0.000 traitlets.py:676(__get__)
     4896    0.006    0.000    1.728    0.000 _dill.py:106(dumps)
        1    0.006    0.006    0.006    0.006 {method 'tolist' of 'torch._C.TensorBase' objects}
     9342    0.006    0.000    0.007    0.000 module.py:2728(named_children)
      200    0.006    0.000    0.006    0.000 {built-in method torch._C._group_tensors_by_device_and_dtype}
     4896    0.006    0.000    0.027    0.000 pickle.py:206(end_framing)
     7960    0.006    0.000    0.007    0.000 traitlets.py:629(get)
    14761    0.006    0.000    0.006    0.000 enum.py:1123(__new__)
     2649    0.006    0.000    0.214    0.000 inspect.py:3343(signature)
    20041    0.005    0.000    0.005    0.000 {method '__exit__' of '_thread.RLock' objects}
    13402    0.005    0.000    0.007    0.000 module.py:2658(<lambda>)
      200    0.005    0.000    0.033    0.000 modeling_bert.py:465(forward)
     8996    0.005    0.000    0.012    0.000 module.py:2719(children)
      101    0.005    0.000    0.005    0.000 {built-in method torch._ops.profiler._record_function_enter_new}
      200    0.005    0.000    0.033    0.000 modeling_bert.py:551(forward)
     3672    0.005    0.000    0.009    0.000 features.py:118(string_to_arrow)
     2649    0.005    0.000    0.208    0.000 inspect.py:3081(from_callable)
      613    0.005    0.000    0.007    0.000 inspect.py:3133(_bind)
    15141    0.005    0.000    0.005    0.000 {method 'startswith' of 'str' objects}
      146    0.005    0.000    0.005    0.000 {built-in method torch.zeros_like}
     5488    0.005    0.000    0.008    0.000 serialization.py:1136(persistent_id)
       51    0.005    0.000    0.005    0.000 {built-in method torch._C._nn.pad_sequence}
     3277    0.005    0.000    0.005    0.000 {built-in method builtins.sorted}
    17408    0.005    0.000    0.005    0.000 {method 'encode' of 'str' objects}
      512    0.004    0.000    0.005    0.000 attrsettr.py:66(_get_attr_opt)
      200    0.004    0.000    0.004    0.000 {built-in method torch._C._nn.gelu}
     7550    0.004    0.000    0.066    0.000 module.py:2608(parameters)
      100    0.004    0.000    0.004    0.000 {built-in method torch._foreach_lerp_}
       50    0.004    0.000    0.004    0.000 {method 'sum' of 'torch._C.TensorBase' objects}
     1346    0.004    0.000    0.018    0.000 decoder.py:333(decode)
      200    0.004    0.000    0.067    0.000 modeling_bert.py:364(forward)
       50    0.004    0.000    0.389    0.008 modeling_bert.py:1001(forward)
    11398    0.004    0.000    0.096    0.000 module.py:2633(named_parameters)
     8050    0.004    0.000    0.005    0.000 __init__.py:367(is_compiling)
     4896    0.004    0.000    0.006    0.000 __init__.py:1964(isEnabledFor)
       73    0.004    0.000    0.004    0.000 {method 'tobytes' of 'numpy.ndarray' objects}
     2704    0.004    0.000    0.004    0.000 {method 'match' of 're.Pattern' objects}
      150    0.004    0.000    0.004    0.000 {built-in method torch._foreach_mul_}
       50    0.004    0.000    0.726    0.015 trainer.py:3668(training_step)
       51    0.004    0.000    0.027    0.001 885070178.py:1(collate_fn)
     1346    0.004    0.000    0.023    0.000 __init__.py:299(loads)
      355    0.004    0.000   27.985    0.079 selectors.py:319(select)

I think lhonestq will be able to help you with the details, but for now, it seems that Arrow is not recommended for processing large datasets.

The following are general improvement suggestions from Hugging Chat.


To address the slow .select(indices) call on the Arrow dataset, here are the step-by-step optimizations:

Step 1: Pass a Contiguous Range Instead of a List of Indices

Why: datasets has a fast path for contiguous selections, so passing a contiguous range (or slicing directly) is much cheaper than materializing a Python list of individual indices.

How:

  • Compute each chunk’s start and end offsets from the stored (start, count) boundaries.
  • Pass a contiguous range to .select() instead of list(range(...)); a direct-slicing variant is sketched after the code adjustment.

Code Adjustment:

# Modify __getitem__ to query the chunk with a single contiguous range
def __getitem__(self, idx):
    user_id, chunk_idx = self.index_mapping[idx]
    start, count, label = self.user_boundaries[user_id]
    offset_start = start + chunk_idx * self.chunk_size
    offset_end = min(offset_start + self.chunk_size, start + count)
    # select() takes its fast contiguous path when given a range
    events = self.arrow_ds.select(range(offset_start, offset_end))
    # ... tokenize `events` and return {"input_ids": ..., "labels": label} as before
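
Going further (an extra sketch on my part, not from the original suggestion): the profile shows most of the time in table.py:106(__init__) and update_metadata_with_features, i.e. in building a new Dataset object for every .select() call. Slicing the dataset directly returns a plain dict of column lists in one call and skips that per-chunk Dataset construction, assuming each chunk's rows are contiguous:

# Sketch: read the chunk by slicing; ds[a:b] returns a dict mapping column
# names to lists, so no new Dataset object is created per chunk.
def __getitem__(self, idx):
    user_id, chunk_idx = self.index_mapping[idx]
    start, count, label = self.user_boundaries[user_id]
    offset_start = start + chunk_idx * self.chunk_size
    offset_end = min(offset_start + self.chunk_size, start + count)
    batch = self.arrow_ds[offset_start:offset_end]
    # Rebuild per-event dicts so the existing tokenize_event() can be reused.
    rows = (dict(zip(batch.keys(), values)) for values in zip(*batch.values()))
    chunk_tokens = []
    for event in rows:
        chunk_tokens.extend(tokenize_event(event))
    return {"input_ids": chunk_tokens, "labels": label}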

Step 2: Keep the Dataset in Memory

Why: With 64GB of RAM and a roughly 20GB dataset, loading everything into memory removes disk I/O and speeds up repeated access.

How:

  • Pass keep_in_memory=True to load_from_disk.
  • Use .with_format("numpy") (or .set_format("numpy")) to read columns as NumPy arrays for faster access.

Code Adjustment:

# Load the Arrow dataset fully into memory and read it as NumPy arrays
arrow_ds = load_from_disk("path/to/arrow_dataset", keep_in_memory=True)
arrow_ds = arrow_ds.with_format("numpy")

Step 3: Optimize Index Creation

Why: Batch processing during index creation reduces the number of operations and leverages Arrow’s efficiency for contiguous data.

How:

  • Increase the batch size during dataset initialization to process larger chunks at once.

Code Adjustment:

# Increase batch_size for more efficient indexing
train_dataset = ArrowBrowsingDataset(train_ds, chunk_size=512, batch_size=100000)

Step 4: Adjust Data Shuffling Strategy

Why: Shuffling can disrupt contiguous data access, leading to slower performance.

How:

  • Avoid random shuffling of the index_mapping to maintain data locality.
  • If randomness is needed, consider coarser alternatives, for example shuffling the order of users only (see the sketch after the code adjustment).

Code Adjustment:

# Remove random.shuffle if it disrupts data locality
# random.shuffle(self.index_mapping)
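
If some randomness is still wanted, one possible compromise (a sketch on my part, not from the original suggestion) is to shuffle the order of users while keeping each user's chunks sequential, so reads within a user stay contiguous. Note that the Trainer's default training DataLoader uses a random sampler, so sample order is shuffled at that level regardless of this dataset-level shuffle.

# Sketch (inside __init__): shuffle which user comes first, but keep each
# user's chunks in their original order. `random` is already imported above.
from collections import defaultdict

chunks_per_user = defaultdict(list)
for uid, chunk_idx in self.index_mapping:
    chunks_per_user[uid].append(chunk_idx)

user_order = list(chunks_per_user)
random.shuffle(user_order)
self.index_mapping = [
    (uid, chunk_idx)
    for uid in user_order
    for chunk_idx in sorted(chunks_per_user[uid])
]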

Step 5: Precompute Tokens During Initialization

Why: Reduces the overhead of tokenization during training, speeding up data retrieval.

How:

  • Precompute tokens for each event during dataset initialization.
  • Store the tokenized chunks on the wrapper object so that __getitem__ becomes a simple lookup.

Code Adjustment:

# Precompute token chunks per user at the end of __init__ (sketch)
class ArrowBrowsingDataset(Dataset):
    def __init__(self, arrow_dataset, chunk_size=512, batch_size=100000):
        # ... existing boundary-scanning code ...
        self.tokenized_data = {}
        for user_id, (start, count, _label) in self.user_boundaries.items():
            events = self.arrow_ds.select(range(start, start + count))
            per_event_tokens = [tokenize_event(event) for event in events]
            # Group the per-event token lists into chunks of chunk_size events.
            # Note: this keeps the entire tokenized corpus in RAM.
            self.tokenized_data[user_id] = [
                [t for ev in per_event_tokens[i:i + self.chunk_size] for t in ev]
                for i in range(0, len(per_event_tokens), self.chunk_size)
            ]

Step 6: Parallelize Data Loading

Why: Utilizes multiple CPU cores to load data faster, improving overall efficiency.

How:

  • Set dataloader_num_workers in TrainingArguments; the Trainer builds its own DataLoader, so num_workers cannot be passed to it directly.
  • Set dataloader_pin_memory=True to speed up data transfer to the GPU (see the Windows note after the code adjustment).

Code Adjustment:

# DataLoader settings are configured through TrainingArguments, not the Trainer
training_args = TrainingArguments(
    # ... other parameters ...
    per_device_train_batch_size=8,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    # ... other parameters ...
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
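
One caveat worth adding (my note, not part of the original suggestion): on Windows, DataLoader worker processes are started with spawn, so when this runs as a plain script with dataloader_num_workers > 0, the training call should sit under a main guard:

# Minimal sketch: guard the entry point so worker processes can safely
# re-import the module on Windows.
if __name__ == "__main__":
    trainer.train()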

Final Code Implementation

# Modified __getitem__ using precomputed token chunks
class ArrowBrowsingDataset(Dataset):
    def __getitem__(self, idx):
        user_id, chunk_idx = self.index_mapping[idx]
        _start, _count, label = self.user_boundaries[user_id]
        # Tokens were precomputed in __init__, so no Arrow query is needed here.
        return {"input_ids": self.tokenized_data[user_id][chunk_idx], "labels": label}

# Load dataset with optimizations (fully in memory, NumPy format)
arrow_ds = load_from_disk("path/to/arrow_dataset", keep_in_memory=True)
arrow_ds = arrow_ds.with_format("numpy")

# Create datasets with optimized parameters
train_dataset = ArrowBrowsingDataset(train_ds, chunk_size=512, batch_size=100000)
test_dataset = ArrowBrowsingDataset(test_ds, chunk_size=512, batch_size=100000)

# Configure parallel data loading through TrainingArguments
training_args = TrainingArguments(
    # ... other parameters ...
    per_device_train_batch_size=8,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

Summary of Optimizations

  1. Contiguous Selection: Replaced lists of individual indices with contiguous ranges (or direct slicing) for faster retrieval.
  2. In-Memory Loading and NumPy Format: Loaded the dataset with keep_in_memory=True and enabled NumPy access for efficient operations.
  3. Efficient Indexing: Increased batch size during indexing to reduce operation counts.
  4. Data Shuffling Adjustment: Maintained data locality by avoiding random shuffling within the dataset.
  5. Precomputation: Tokenized data during initialization to minimize training-time overhead.
  6. Parallel Loading: Leveraged multiple workers and pinned memory for faster data processing.

These optimizations should significantly reduce the .select(indices) bottleneck and improve overall training efficiency.
