Removing columns from streaming datasets doesn't work

Hey guys, this is a portion of my trainer’s code:

def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    
    # optional pre-processing steps
    transcription = batch["sentence"]
    
    # encode target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    batch["labels_length"] = len(batch["labels"])
    return batch



vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)).with_format("torch")



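# input_length is computed in seconds in prepare_dataset, so keep this limit in seconds
# (multiplying it by 16000 would compare seconds against a sample count, so long clips would never be filtered out)
max_input_length = 30.0
max_labels_length = 500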


def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

def filter_labels(labels_length):
    """Filter empty label sequences"""
    return 0 < labels_length < max_labels_length


vectorized_datasets["train"] = vectorized_datasets['train'].filter(filter_inputs, input_columns=["input_length"])
vectorized_datasets["train"] = vectorized_datasets['train'].filter(filter_labels, input_columns=["labels_length"])
vectorized_datasets.remove_columns(['input_length', 'labels_length'])
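One thing worth checking with plain numbers is that each filter bound is in the same units as the column it is compared against: prepare_dataset stores input_length in seconds, so a bound scaled to samples never removes anything. A minimal sketch, assuming a 16 kHz sampling rate as in the post; input_length_seconds and keep_input are hypothetical helpers that mirror the code above:

```python
# Hypothetical sanity check: filter bounds must use the same units
# as the columns they are compared against.
SAMPLING_RATE = 16_000  # assumed target sampling rate (16 kHz, as in the post)


def input_length_seconds(num_samples, sampling_rate=SAMPLING_RATE):
    """Mirror of prepare_dataset: input_length = samples / sampling_rate (seconds)."""
    return num_samples / sampling_rate


def keep_input(input_length, max_input_length):
    """Same predicate as filter_inputs, with the bound passed explicitly."""
    return 0 < input_length < max_input_length


clip_45s = input_length_seconds(45 * SAMPLING_RATE)  # a 45-second clip -> 45.0

# Bound scaled to samples (30 * 16000): the 45 s clip is NOT filtered out.
kept_with_sample_bound = keep_input(clip_45s, 30.0 * SAMPLING_RATE)
# Bound in seconds: the 45 s clip is filtered out, as the docstring intends.
kept_with_second_bound = keep_input(clip_45s, 30.0)

print(kept_with_sample_bound, kept_with_second_bound)  # True False
```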

After preparing the dataset, I try to remove the columns I no longer need.
This is actually part of debugging another problem: the filters I add to my dataset don't seem to take effect, and I haven't managed to find out why.

I'm applying filter_inputs and filter_labels to remove entries that are empty or too long (for the labels, 500 tokens or more).
But when I run the trainer, I get the same error as described in "Trainer RuntimeError: The size of tensor a (462) must match the size of tensor b (448) at non-singleton dimension 1".
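The numbers in that error can be checked against the label filter directly. Assuming the 448 in "tensor b (448)" is the largest label length the model accepts (read off the error message, not taken from any model config), a 462-token sequence still passes a bound of 500. keep_labels and ASSUMED_MODEL_MAX_LABELS are hypothetical names for this sketch:

```python
# Hypothetical check: does the label bound exclude sequences that are too
# long for the model? 448 is an ASSUMPTION read off the error message.
ASSUMED_MODEL_MAX_LABELS = 448


def keep_labels(labels_length, max_labels_length):
    """Same predicate as filter_labels, with the bound passed explicitly."""
    return 0 < labels_length < max_labels_length


offending_length = 462  # "tensor a (462)" from the error message

# With the bound of 500 the 462-token sequence is kept.
print(keep_labels(offending_length, 500))  # True
# The comparison is strict, so allowing at most 448 tokens needs a bound of 449.
print(keep_labels(offending_length, ASSUMED_MODEL_MAX_LABELS + 1))  # False
```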

So my question is: how can I debug a streaming dataset to make sure the filters are actually applied?
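Since map and filter on a streaming (iterable) dataset only run when it is iterated, one way to inspect it is to pull the first few examples and test them against the same bounds. A minimal sketch, assuming vectorized_datasets["train"] yields dicts that still contain input_length and labels_length (so it has to run before those columns are removed); check_stream is a hypothetical helper, and itertools.islice works on any iterable, including an IterableDataset:

```python
from itertools import islice


def check_stream(examples, n, max_input_length, max_labels_length):
    """Return (index, input_length, labels_length) for every example among
    the first n that falls outside the filter bounds."""
    violations = []
    for i, ex in enumerate(islice(examples, n)):
        # float()/int() also accept 0-d torch tensors (with_format("torch")).
        input_length = float(ex["input_length"])
        labels_length = int(ex["labels_length"])
        ok = (0 < input_length < max_input_length
              and 0 < labels_length < max_labels_length)
        if not ok:
            violations.append((i, input_length, labels_length))
    return violations


# Toy stand-in for a stream of examples (hypothetical values).
toy = [
    {"input_length": 12.5, "labels_length": 40},
    {"input_length": 45.0, "labels_length": 40},
    {"input_length": 10.0, "labels_length": 462},
]
print(check_stream(toy, n=10, max_input_length=30.0, max_labels_length=500))
# [(1, 45.0, 40)]
```

On the real data this would be called as check_stream(vectorized_datasets["train"], 100, max_input_length, max_labels_length); an empty list means none of the first 100 examples violate the bounds.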
Thank you.

Hi! remove_columns doesn't modify the dataset in place, it returns a new dataset, so you should reassign the result:

vectorized_datasets = vectorized_datasets.remove_columns(['input_length', 'labels_length'])