Custom 20GB Arrow dataset very slow to train

I think lhoestq will be able to help you with the details, but for now, it seems that this kind of per-item random access into an Arrow dataset is not recommended when the dataset is large.

The following are general improvement suggestions from Hugging Chat.


To address the slow .select(indices) call on the Arrow dataset, here are the step-by-step optimizations:

Step 1: Replace List-Based Indices with Contiguous Ranges

Why: Passing a contiguous range to .select() is much faster than passing a list of scattered indices, because the datasets library can slice the underlying Arrow table directly instead of building an indices mapping.

How:

  • Modify the index mapping to store start and end offsets instead of lists of individual indices.
  • Pass range(start, end) to .select() to retrieve each contiguous chunk.

Code Adjustment:

# Modify __getitem__ to read each chunk as one contiguous range
def __getitem__(self, idx):
    user_id, chunk_idx = self.index_mapping[idx]
    start_idx, count = self.user_boundaries[user_id]
    offset_start = start_idx + chunk_idx * self.chunk_size
    offset_end = min(offset_start + self.chunk_size, start_idx + count)
    # A contiguous range lets `datasets` slice the Arrow table directly
    events = self.arrow_ds.select(range(offset_start, offset_end))
    # ... tokenize and return as before ...
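
If you want to confirm the difference on your own data first, a quick comparison like the sketch below (the dataset path is a placeholder) typically shows the contiguous-range call completing far faster, because a list of scattered indices forces datasets to build an indices mapping:

import random
import time
from datasets import load_from_disk

ds = load_from_disk("path/to/arrow_dataset")  # placeholder path

# Contiguous range: the underlying Arrow table is sliced directly
t0 = time.perf_counter()
_ = ds.select(range(0, 512))
print("contiguous range:", time.perf_counter() - t0)

# Scattered list of indices: an indices mapping has to be built
scattered = random.sample(range(len(ds)), 512)
t0 = time.perf_counter()
_ = ds.select(scattered)
print("scattered list:", time.perf_counter() - t0)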

Step 2: Enable Dataset Caching

Why: Ensures data stays in memory, reducing disk I/O and speeding up repeated access.

How:

  • Pass keep_in_memory=True to load_from_disk so the dataset is loaded into RAM instead of being memory-mapped from disk (only viable if it fits in memory).
  • Use .with_format (or .set_format) to read data as NumPy arrays for faster access.

Code Adjustment:

from datasets import load_from_disk

# Load the Arrow dataset into RAM (requires enough memory for the full dataset)
arrow_ds = load_from_disk("path/to/arrow_dataset", keep_in_memory=True)
# Read columns as NumPy arrays
arrow_ds = arrow_ds.with_format("numpy")
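
If each training example only needs a couple of columns, it can also help to restrict the format to those columns so the others are never decoded. The column names below are placeholders for your actual schema:

# Only materialize the columns the model actually needs (column names are examples)
arrow_ds = arrow_ds.with_format("numpy", columns=["events", "label"])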

Step 3: Optimize Index Creation

Why: Batch processing during index creation reduces the number of operations and leverages Arrow’s efficiency for contiguous data.

How:

  • Increase the batch size during dataset initialization to process larger chunks at once.

Code Adjustment:

# Increase batch_size for more efficient indexing
train_dataset = ArrowBrowsingDataset(train_ds, chunk_size=512, batch_size=100000)
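
As a concrete sketch of what batched index creation can look like (assuming a "user_id" column and that each user's rows are stored contiguously, which is an assumption about your data layout), the boundaries can be built by scanning the dataset in large contiguous slices rather than row by row:

from collections import OrderedDict

def build_user_boundaries(arrow_ds, batch_size=100000):
    """Scan the dataset in contiguous batches and record, for each user,
    the offset of their first row and their total row count."""
    boundaries = OrderedDict()  # user_id -> [start, count]
    for start in range(0, len(arrow_ds), batch_size):
        batch = arrow_ds[start : start + batch_size]  # dict of column -> values
        for offset, user_id in enumerate(batch["user_id"]):
            if user_id not in boundaries:
                boundaries[user_id] = [start + offset, 0]
            boundaries[user_id][1] += 1
    return boundaries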

Step 4: Adjust Data Shuffling Strategy

Why: Shuffling can disrupt contiguous data access, leading to slower performance.

How:

  • Avoid random shuffling of the index_mapping to maintain data locality.
  • If randomness is needed, shuffle at the user or chunk level after indexing (see the sketch after the code below) so each read stays contiguous.

Code Adjustment:

# Remove random.shuffle if it disrupts data locality
# random.shuffle(self.index_mapping)
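
If some randomness is still needed, one locality-friendly option (a sketch, not part of the original code) is to shuffle at the user level while keeping each user's chunks in order, so every individual read remains a contiguous range and consecutive reads still tend to land on nearby rows:

import random

def locality_friendly_order(index_mapping, seed=0):
    """Shuffle users, but keep each user's chunks in their original order."""
    by_user = {}
    for user_id, chunk_idx in index_mapping:
        by_user.setdefault(user_id, []).append((user_id, chunk_idx))
    user_ids = list(by_user)
    random.Random(seed).shuffle(user_ids)
    return [entry for uid in user_ids for entry in by_user[uid]]

# e.g. once per epoch:
# train_dataset.index_mapping = locality_friendly_order(train_dataset.index_mapping, seed=epoch)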

Step 5: Precompute Tokens During Initialization

Why: Reduces the overhead of tokenization during training, speeding up data retrieval.

How:

  • Precompute tokens for each event during dataset initialization.
  • Store tokenized sequences in the dataset for faster access.

Code Adjustment:

# Precompute tokens in __init__
class ArrowBrowsingDataset(Dataset):
    def __init__(self, arrow_dataset, chunk_size=512, batch_size=100000):
        # ... existing code ...
        self.tokenized_data = []
        # For each contiguous chunk, tokenize once up front
        events = self.arrow_ds.select(range(start, end))
        tokenized_events = [tokenize_event(event) for event in events]
        self.tokenized_data.append(tokenized_events)
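
An alternative to tokenizing inside __init__ is to let the datasets library do the precomputation with map(batched=True), so the tokenized version is written to the Arrow cache once and reused across runs. This sketch assumes tokenize_event is your existing per-event tokenizer and that it works on a single row dict:

def tokenize_batch(batch):
    # `batch` is a dict of column -> list of values; rebuild per-row dicts
    num_rows = len(next(iter(batch.values())))
    rows = [{key: batch[key][i] for key in batch} for i in range(num_rows)]
    return {"input_ids": [tokenize_event(row) for row in rows]}

# Tokenize once; results are cached on disk by `datasets`
arrow_ds = arrow_ds.map(tokenize_batch, batched=True, batch_size=10000, num_proc=4)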

Step 6: Parallelize Data Loading

Why: Utilizes multiple CPU cores to load data faster, improving overall efficiency.

How:

  • Set dataloader_num_workers in TrainingArguments so the Trainer's DataLoader uses multiple worker processes.
  • Enable dataloader_pin_memory to speed up host-to-GPU transfers.

Code Adjustment:

# Data-loading options are TrainingArguments fields, not Trainer arguments
training_args = TrainingArguments(
    # ... other parameters ...
    per_device_train_batch_size=8,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    # ... other parameters ...
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
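
For reference, if you build the training loop yourself instead of using the Trainer, the same two settings live directly on PyTorch's DataLoader (collate_fn here is your existing collator):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    num_workers=4,     # worker processes running __getitem__ in parallel
    pin_memory=True,   # page-locked host memory for faster GPU transfers
    collate_fn=collate_fn,
)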

Final Code Implementation

# Modified __getitem__ using contiguous ranges and precomputed tokens
class ArrowBrowsingDataset(Dataset):
    def __getitem__(self, idx):
        user_id, chunk_idx = self.index_mapping[idx]
        start, count, label = self.user_boundaries[user_id]
        offset_start = start + chunk_idx * self.chunk_size
        offset_end = min(offset_start + self.chunk_size, start + count)
        # Only needed when tokenizing on the fly; with precomputed tokens this read can be skipped
        events = self.arrow_ds.select(range(offset_start, offset_end))
        # Precomputed tokens for faster access
        return {"input_ids": self.tokenized_data[user_id][chunk_idx], "labels": label}

# Load dataset with optimizations
arrow_ds = load_from_disk("path/to/arrow_dataset", keep_in_memory=True)
arrow_ds = arrow_ds.with_format("numpy")

# Create datasets with optimized parameters
train_dataset = ArrowBrowsingDataset(train_ds, chunk_size=512, batch_size=100000)
test_dataset = ArrowBrowsingDataset(test_ds, chunk_size=512, batch_size=100000)

# Configure parallel data loading for the Trainer (TrainingArguments fields)
training_args = TrainingArguments(
    # ... other parameters ...
    per_device_train_batch_size=8,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)
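
Putting it together with the Trainer (model creation omitted; model, collate_fn, and compute_metrics are assumed to come from your existing setup):

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
trainer.train()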

Summary of Optimizations

  1. Range-Based Selection: Replaced lists of scattered indices with contiguous ranges for faster chunk retrieval.
  2. In-Memory Loading and NumPy Format: Loaded the dataset with keep_in_memory=True and read columns as NumPy arrays for efficient access.
  3. Efficient Indexing: Increased batch size during indexing to reduce operation counts.
  4. Data Shuffling Adjustment: Maintained data locality by avoiding random shuffling within the dataset.
  5. Precomputation: Tokenized data during initialization to minimize training-time overhead.
  6. Parallel Loading: Leveraged multiple workers and pinned memory for faster data processing.

These optimizations should significantly reduce the .select(indices) bottleneck and improve overall training efficiency.
