Multi-label audio classification training: batch size mismatch

After running the train function I keep getting a size mismatch:

trainer.train()
ValueError: Expected input batch_size (4) to match target batch_size (20).

Here is my setup:

My Dataset loader class:

import librosa
import torch

class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, labels, feature_extractor, max_audio_length_seconds=30):
        self.file_paths = file_paths
        self.labels = labels
        self.feature_extractor = feature_extractor
        self.audio_sample_rate = feature_extractor.sampling_rate
        self.max_audio_length_seconds = max_audio_length_seconds
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        # Load the audio file and resample it to the feature extractor's rate
        waveform, original_sr = librosa.load(self.file_paths[idx], sr=None)
        waveform = librosa.resample(waveform, orig_sr=original_sr, target_sr=self.audio_sample_rate)

        # Encode the audio waveform with the feature extractor
        inputs = self.feature_extractor(
            waveform, 
            sampling_rate=self.audio_sample_rate, 
            max_length=self.max_audio_length_seconds*self.audio_sample_rate, 
            truncation=True, 
            padding="max_length",
            return_tensors="pt",
        )
        
        input_values = inputs.input_values.squeeze()  # Remove batch dimension added by return_tensors
        labels = torch.tensor(self.labels[idx], dtype=torch.float)

        return {"input_values": input_values, "labels": labels}

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_TYPE)
torch.cuda.empty_cache()  # free any cached GPU memory before training
dataset = AudioDataset(
    file_paths=files, 
    labels=labels, 
    feature_extractor=feature_extractor
)
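
(train_size and val_size aren't defined in the snippet; a minimal sketch, assuming an 80/20 split — the ratio is an assumption, not from the original code:)

train_size = int(0.8 * len(dataset))  # hypothetical 80/20 split
val_size = len(dataset) - train_size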

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

My model init:

model = AutoModelForAudioClassification.from_pretrained(
    MODEL_TYPE,
    num_labels=len(label_columns),
    problem_type="multi_label_classification",
    id2label={idx: label for idx, label in enumerate(label_columns)},
    label2id={label: idx for idx, label in enumerate(label_columns)},
)

Trainer Args:

training_args = TrainingArguments(
    output_dir=os.path.join(temp, 'output'),
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    logging_dir=os.path.join(temp, 'logs'),
    learning_rate=LEARNING_RATE,
    load_best_model_at_end=True,
    metric_for_best_model='auc_score_macro',
    disable_tqdm=False,
    optim="adamw_torch",
    fp16=True,
    run_name=MODEL_NAME
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

I am not sure where the mismatch is. I have done some additional testing (below), but it seems to be an issue with the Trainer itself.

MODEL_TYPE = "facebook/wav2vec2-base"
Number of output classes = 5
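
A quick sanity check that these settings actually landed on the config (using the model loaded above):

print(model.config.problem_type)  # expect "multi_label_classification"
print(model.config.num_labels)    # expect 5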

I tried this as well to make sure the batches look right:

from torch.utils.data import DataLoader

# Assuming train_dataset is already defined
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Load one batch and inspect
for batch in train_loader:
    print(f"Input Values Shape: {batch['input_values'].shape}")  # Output was [4, 480000]
    print(f"Labels Shape: {batch['labels'].shape}")  # Output was [4, 5] for multi-label classification
    break

Hi,

The Trainer API creates batches for you using the default data collator, see transformers/src/transformers/data/data_collator.py at v4.38.1 in the huggingface/transformers GitHub repo.

You might check whether this one is appropriate for your data. Also note that your labels need to be of type "float" in order to work with BCEWithLogitsLoss (they should not be of type "long").
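
A minimal standalone illustration of that dtype requirement (nothing here comes from the code above):

import torch

logits = torch.randn(4, 5)               # fake model outputs: batch of 4, 5 labels
labels = torch.randint(0, 2, (4, 5))     # integer 0/1 labels

loss_fct = torch.nn.BCEWithLogitsLoss()
# loss_fct(logits, labels)               # fails: BCEWithLogitsLoss expects float targets
loss = loss_fct(logits, labels.float())  # works once the labels are cast to float
print(loss)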

I just ran the default collator, and here is the verification with a sample of data:

from transformers.data.data_collator import torch_default_data_collator

# Take a sample batch of 4 items
samples = [train_dataset[i] for i in range(4)]  # adjust the range as necessary

# Collate the sample batch
collated_batch = torch_default_data_collator(samples)

# Inspect the batched data
print(f"Batched Input Values Shape: {collated_batch['input_values'].shape}") # Expected: [4, seq_length]
print(f"Batched Labels Shape: {collated_batch['labels'].shape}") # Expected: [4, 5] for multi-label classification
print(f"The datatype of the data: {collated_batch['labels'][0][0].dtype}")

Batched Input Values Shape: torch.Size([4, 480000])
Batched Labels Shape: torch.Size([4, 5])
The datatype of the data: torch.float32

Seems like everything is in order?
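
One further check worth running (a sketch, assuming model and collated_batch from above are still in scope): pass the collated batch, labels included, straight through the model outside the Trainer. The loss is computed inside the forward pass, so if the same ValueError shows up here, the mismatch comes from the model's loss computation rather than from the data pipeline. Note that 20 is exactly 4 × 5, as if the [4, 5] label tensor were being flattened.

model.eval()
with torch.no_grad():
    outputs = model(**collated_batch)  # loss is computed in forward when labels are present
print(outputs.loss)
print(outputs.logits.shape)  # expect torch.Size([4, 5])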