I’m mapping my Audio dataset to the function below:
import base64
from io import BytesIO
import torchaudio
import torch
import re
def prepare_dataset(batch):
# Load MP3 audio bytes
audio_bytes = batch['audio']['bytes']
# Check if audio_bytes is empty or None
if not audio_bytes:
raise ValueError("Audio bytes are missing or null")
# Wrap the audio bytes in a BytesIO object
audio_file = BytesIO(audio_bytes)
# Use torchaudio to load the MP3 audio
# Ensure that the appropriate backend is used
# You can use "sox_io" or "ffmpeg" depending on your installation
torchaudio.set_audio_backend("ffmpeg") # or "ffmpeg" if FFmpeg is installed
incoming_waveform, sample_rate = torchaudio.load(audio_file, format='mp3')
except Exception as e:
raise ValueError(f"Failed to load MP3 audio with torchaudio: {e}")
# Optional resampling to 16kHz if required
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
incoming_waveform = resampler(incoming_waveform)
# Compute input features using your processor's feature extractor
batch["input_features"] = processor.feature_extractor(
incoming_waveform.squeeze().numpy(), sampling_rate=target_sample_rate
except Exception as e:
raise ValueError(f"Error extracting features: {e}")
# Compute the input length in seconds
batch["input_length"] = incoming_waveform.size(1) / target_sample_rate
# Process transcription and labels
# Try both 'transcription' and 'sentence' keys if applicable
transcription = batch.get("transcription")
# If 'transcription' is not available, you can try 'sentence' or another key
# transcription = transcription or batch.get("sentence")
# Check if transcription is missing or null
if not transcription or not isinstance(transcription, str):
raise ValueError("Transcription is missing or null")
# Strip leading and trailing whitespace
transcription = transcription.strip()
if not transcription:
raise ValueError("Transcription is empty after stripping")
if do_lower_case:
transcription = transcription.lower()
if do_remove_punctuation:
transcription = re.sub(punctuation_to_remove_regex, " ", transcription).strip()
# Encode target text to label ids
batch["labels"] = processor.tokenizer(transcription).input_ids
except Exception as e:
raise ValueError(f"Error tokenizing transcription: {e}")
return batch
except ValueError as ve:
print(f"Skipping corrupted data: {ve}")
return None # Returning None will exclude this batch from the final dataset
When calling the mapping function, the arrow writer will raise the following error:
vectorized_datasets = dataset_2.map(
File /usr/local/lib/python3.11/dist-packages/datasets/arrow_writer.py:450, in <listcomp>(.0)
447 batch_examples[col] = pa.concat_arrays(arrays)
448 else:
449 batch_examples[col] = [
--> 450 row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
451 for row in self.current_examples
452 ]
453 self.write_batch(batch_examples=batch_examples)
454 self.current_examples = []
TypeError: 'NoneType' object is not subscriptable
This happens specifically at the end of a shard (1000th example in my case):
File /usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py:3035, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3029 if transformed_dataset is None:
3030 with hf_tqdm(
3031 unit=" examples",
3032 total=pbar_total,
3033 desc=desc or "Map",
3034 ) as pbar:
-> 3035 for rank, done, content in Dataset._map_single(**dataset_kwargs):
3036 if done:
3037 shards_done += 1
File /usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py:3473, in Dataset._map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
3471 if update_data:
3472 if writer is not None:
-> 3473 writer.finalize()
3474 if tmp_file is not None:
3475 tmp_file.close()