Unexpected Keyword Argument

import torch
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperProcessor
from peft import get_peft_model, LoraConfig, TaskType
import os
from torch.utils.data import Dataset, DataLoader
import torchaudio
import pandas as pd

Load the feature extractor, tokenizer, and processor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Tamil", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Tamil", task="transcribe")

Load the model

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

Define the target modules for LoRA based on the actual model structure

target_modules = []
for i in range(12):
    # Encoder layers: self-attention projections and feed-forward layers
    for proj in ["k_proj", "v_proj", "q_proj", "out_proj"]:
        target_modules.append(f"model.encoder.layers.{i}.self_attn.{proj}")
    target_modules += [f"model.encoder.layers.{i}.fc1", f"model.encoder.layers.{i}.fc2"]
for i in range(12):
    # Decoder layers: self-attention, cross-attention, and feed-forward layers
    for attn in ["self_attn", "encoder_attn"]:
        for proj in ["k_proj", "v_proj", "q_proj", "out_proj"]:
            target_modules.append(f"model.decoder.layers.{i}.{attn}.{proj}")
    target_modules += [f"model.decoder.layers.{i}.fc1", f"model.decoder.layers.{i}.fc2"]

Configure LoRA

config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    target_modules=target_modules,
    lora_dropout=0.1,
)

# Print all module names to verify the LoRA target module paths
for name, module in model.named_modules():
    print(name)

Apply LoRA to the model

lora_model = get_peft_model(model, config)

Set model to training mode

lora_model.train()

class WhisperDataset(Dataset):
    def __init__(self, csv_file, audio_dir, processor):
        self.data = pd.read_csv(csv_file)
        self.audio_dir = audio_dir
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        audio_path = f"{self.audio_dir}/{self.data.iloc[idx, 0]}"
        transcription = self.data.iloc[idx, 1]

        # Load audio file
        speech_array, sampling_rate = torchaudio.load(audio_path)

        # Process the audio and transcription
        input_features = self.processor.feature_extractor(speech_array.squeeze().numpy(), sampling_rate=sampling_rate, return_tensors="pt").input_features
        labels = self.processor.tokenizer(transcription, return_tensors="pt").input_ids
        # print(f"Processing index: {idx}")
        # print(f"Transcription: {transcription}")
        # print(f"Input features shape: {input_features.shape}")
        # print(f"Labels (input_ids): {labels}")

        return {
            "input_features": input_features.squeeze(),
            "labels": labels.squeeze()
        }

Define a custom collate function to pad sequences

def collate_fn(batch):
    input_features = [item['input_features'] for item in batch]
    labels = [item['labels'] for item in batch]

# Pad sequences
input_features_padded = torch.nn.utils.rnn.pad_sequence(input_features, batch_first=True, padding_value=0)
labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_token_id)

return {
    'input_features': input_features_padded,
    'labels': labels_padded
}

Example usage

csv_file = "./dup.csv"  # Path to your CSV file
audio_dir = "."  # Path to the directory containing your audio files
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Tamil", task="transcribe")

dataset = WhisperDataset(csv_file, audio_dir, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

Define the optimizer

optimizer = torch.optim.Adam(lora_model.parameters(), lr=1e-4)

Fine-tuning loop

num_epochs = 3  # Define the number of epochs
for epoch in range(num_epochs):
    for batch in dataloader:
        # Forward pass
        input_features = batch['input_features'].to(lora_model.device)
        labels = batch['labels'].to(lora_model.device)

        print(f"Batch input features shape: {input_features.shape}")
        print(f"Batch labels shape: {labels.shape}")
        print(f"Batch labels: {labels}")
        # print(input_features.shape)
        outputs = lora_model(input_features=input_features, labels=labels)
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

Save the fine-tuned model

lora_model.save_pretrained("/content/fine_tuned_whisper_with_lora")

TypeError: WhisperForConditionalGeneration.forward() got an unexpected keyword argument 'input_ids'

The above error is raised when running this code.

Same here. It seems that the PEFT wrapper always calls the model with the input_ids and inputs_embeds kwargs, which the Whisper model's forward() does not accept.

It looks like get_peft_model is only used to pick the appropriate wrapper model based on the task_type argument from LoraConfig.
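For reference, a minimal sketch of what that means in practice (the wrapper class name PeftModelForSeq2SeqLM is assumed from the peft source and may differ across versions; target_modules is shortened here for illustration):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
cfg = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=16,
                 target_modules=["q_proj", "v_proj"])

wrapped = get_peft_model(model, cfg)
print(type(wrapped).__name__)  # expected: PeftModelForSeq2SeqLM
# PeftModelForSeq2SeqLM.forward() forwards input_ids / inputs_embeds to the base
# model, but WhisperForConditionalGeneration.forward() expects input_features,
# hence the "unexpected keyword argument 'input_ids'" TypeError during training.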

Since none of the wrappers work for Whisper, you have to create your own implementation.

So instead of lora_model = get_peft_model(model, config), you can use lora_model = WhisperTuner(model, config).

Where WhisperTuner would be something like this:

import torch
from peft import PeftConfig, PeftModel, PeftType


class WhisperTuner(PeftModel):
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(
            self,
            attention_mask=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            **kwargs,
    ):
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

    def generate(self, **kwargs):
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not peft_config.is_prompt_learning:
                with self._enable_peft_forward_hooks(**kwargs):
                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                    outputs = self.base_model.generate(**kwargs)
        except:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            return outputs

    def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if peft_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids
        if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            past_key_values = self.get_prompt(batch_size)
            model_kwargs["past_key_values"] = past_key_values

        return model_kwargs

(I just copied an implementation of PeftModel and deleted all the paths and entries that are not needed for Whisper; there is probably room for improvement in this code.)
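A minimal usage sketch, reusing the model, config, input_features, and labels variables from the original script above:

# Wrap the base Whisper model with the custom WhisperTuner instead of get_peft_model.
lora_model = WhisperTuner(model, config)
lora_model.train()

# Whisper receives input_features via **kwargs, so forward() no longer sees
# an unexpected input_ids argument.
outputs = lora_model(input_features=input_features, labels=labels)
loss = outputs.loss
loss.backward()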

Using openai/whisper-large-v3 with no-speech detection enabled also triggers the unexpected keyword argument input_ids issue. The root cause is that the _setup_no_speech_detection function in WhisperGenerationMixin adds an input_ids argument that Whisper's forward function does not accept. Here is a fix PR, awaiting review: fix unexpected kws of input_ids when setup no speech detection of whisper by zhaozhenyu-newsbreak · Pull Request #36809 · huggingface/transformers · GitHub
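For context, a hedged sketch of the kind of generate() call that goes down the no-speech-detection path (argument names come from the transformers long-form generation API; the exact combination that triggers _setup_no_speech_detection may vary by version):

# input_features: log-mel features produced by the Whisper processor.
generated_ids = model.generate(
    input_features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # fallback temperatures
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,  # enables the no-speech detection setup internally
)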
