Remove columns from streamable datasets doesn't work

Hey guys, this is a portion of my trainer’s code:

def prepare_dataset(batch):
    """Add model inputs and length bookkeeping fields to one example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch["sentence"]``, then writes back into ``batch``:

    - ``input_features``: output of ``processor.feature_extractor`` (first item)
    - ``input_length``: audio duration in seconds (samples / sampling rate)
    - ``labels``: token ids of the transcription from ``processor.tokenizer``
    - ``labels_length``: number of label token ids

    Nothing here resamples the audio; presumably the audio column was cast to
    the target sampling rate upstream -- TODO confirm.
    Returns the same (mutated) ``batch`` object.
    """
    audio = batch["audio"]
    waveform, rate = audio["array"], audio["sampling_rate"]

    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]
    # Duration in seconds, not a sample count.
    batch["input_length"] = len(waveform) / rate

    token_ids = processor.tokenizer(batch["sentence"]).input_ids
    batch["labels"] = token_ids
    batch["labels_length"] = len(token_ids)
    return batch

# Apply prepare_dataset to every split and drop the original input columns
# (taken from the first split's features), then return torch-formatted rows.
vectorized_datasets = raw_datasets.map(
    prepare_dataset,
    remove_columns=list(next(iter(raw_datasets.values())).features),
).with_format("torch")

# Length limits used by the filters below.
# prepare_dataset stores input_length in SECONDS, so this limit must also be in
# seconds. Converting it to samples (30 * 16000) made the upper bound
# unreachable, and long inputs were never filtered out.
max_input_length = 30.0  # seconds
# NOTE(review): if the model accepts at most 448 label tokens (as the error
# "(462) must match ... (448)" suggests), labels of 448-499 tokens still pass
# this filter -- confirm the model's maximum label length.
max_labels_length = 500  # label tokens


def filter_inputs(input_length):
    """Keep inputs with a duration strictly between 0 and max_input_length seconds."""
    return 0 < input_length < max_input_length


def filter_labels(labels_length):
    """Keep non-empty label sequences shorter than max_labels_length tokens."""
    return 0 < labels_length < max_labels_length

# Drop training examples whose input or label length is out of range.
vectorized_datasets["train"] = vectorized_datasets["train"].filter(filter_inputs, input_columns=["input_length"])
vectorized_datasets["train"] = vectorized_datasets["train"].filter(filter_labels, input_columns=["labels_length"])
# remove_columns returns a new dataset (dict) instead of modifying it in place,
# so the result must be reassigned; otherwise the call has no effect.
vectorized_datasets = vectorized_datasets.remove_columns(['input_length', 'labels_length'])

After preparing the dataset, I try to remove the unused columns.
I'm actually doing this to debug a different problem: the filters I add to my dataset don't seem to take effect, and I can't figure out why.

I'm trying to apply `filter_labels` and `filter_inputs`, which are meant to remove entries with empty or overly long inputs and labels (for example, labels with 500 or more tokens).
But when I run the trainer, I get the same issue as described here: `Trainer RuntimeError: The size of tensor a (462) must match the size of tensor b (448) at non-singleton dimension 1`

So my question is: how can I debug a streaming dataset to make sure the filters are actually applied?
Thank you.

Hi! `remove_columns` doesn't modify the dataset in place; it returns a new one. You need to reassign the result:

# remove_columns does not work in place: it returns a new object, so rebind
# the name to the returned value.
vectorized_datasets = vectorized_datasets.remove_columns(
    ["input_length", "labels_length"]
)

This is what worked for me:

    # -- Get data set
    def my_load_dataset(path, name, data_files=data_files):
        """Load the "train" split of one dataset and format it for torch.

        - ``path`` in ('json', 'bin', 'csv'): loads the file
          ``data_files_prefix + name``.
        - ``path == 'parquet'``: loads the given ``data_files``.
        - any other ``path``: calls ``load_dataset(path, name, ...)``, passing
          ``name`` positionally (as the configuration name).

        ``streaming`` and ``data_files_prefix`` are read from the enclosing
        scope. The default for ``data_files`` is bound when this function is
        defined, not when it is called.
        """
        print(f'{path=} {name=} {streaming=}')
        if path in ('json', 'bin', 'csv'):
            dataset = load_dataset(path, data_files=data_files_prefix+name, streaming=streaming, split="train")
        elif path == 'parquet':
            dataset = load_dataset(path, data_files=data_files, streaming=streaming, split="train")
        else:
            # General case. This branch must be reachable: without it, any
            # other path would fall through and return None.
            dataset = load_dataset(path, name, streaming=streaming, split="train")
        return dataset.with_format("torch")
    # - get data set for real now
    if isinstance(path, str):
        # A single dataset.
        dataset = my_load_dataset(path, name)
    else:
        # -Interleaving datasets: path, name and data_files are parallel sequences.
        print('- Interleaving datasets')
        datasets = [my_load_dataset(p, n, files) for p, n, files in zip(path, name, data_files)]
        for dataset in datasets:
            print(f'{dataset.description=}')  # print description if available
        # - make sure all datasets have the same columns (only 'text') so that
        #   interleave_datasets doesn't complain. Each dataset drops only the
        #   columns it actually has: a non-streaming Dataset.remove_columns raises
        #   ValueError for column names it doesn't contain.
        # NOTE(review): column_names can be None for a streaming dataset whose
        # features are unknown -- confirm the datasets used here expose them.
        datasets = [
            dataset.remove_columns([col for col in dataset.column_names if col != 'text'])
            for dataset in datasets
        ]
        # - interleave
        dataset = interleave_datasets(datasets, probabilities)
1 Like

Omg! Thanks, I was looking for this. I appreciate your contribution.

1 Like