Unexpected Keywork Argument

import torch
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperProcessor
from peft import get_peft_model, LoraConfig, TaskType
import os
from torch.utils.data import Dataset, DataLoader
import torchaudio
import pandas as pd

Load Feature extractor, Tokenizer, and Processor

feature_extractor = WhisperFeatureExtractor.from_pretrained(“openai/whisper-small”)
tokenizer = WhisperTokenizer.from_pretrained(“openai/whisper-small”, language=“Tamil”, task=“transcribe”)
processor = WhisperProcessor.from_pretrained(“openai/whisper-small”, language=“Tamil”, task=“transcribe”)

Load the model

model = WhisperForConditionalGeneration.from_pretrained(“openai/whisper-small”)

Define the target modules for LoRA based on the actual model structure

target_modules = [
# Encoder layers
“model.encoder.layers.0.self_attn.k_proj”,
“model.encoder.layers.0.self_attn.v_proj”,
“model.encoder.layers.0.self_attn.q_proj”,
“model.encoder.layers.0.self_attn.out_proj”,
“model.encoder.layers.0.fc1”,
“model.encoder.layers.0.fc2”,
“model.encoder.layers.1.self_attn.k_proj”,
“model.encoder.layers.1.self_attn.v_proj”,
“model.encoder.layers.1.self_attn.q_proj”,
“model.encoder.layers.1.self_attn.out_proj”,
“model.encoder.layers.1.fc1”,
“model.encoder.layers.1.fc2”,
“model.encoder.layers.2.self_attn.k_proj”,
“model.encoder.layers.2.self_attn.v_proj”,
“model.encoder.layers.2.self_attn.q_proj”,
“model.encoder.layers.2.self_attn.out_proj”,
“model.encoder.layers.2.fc1”,
“model.encoder.layers.2.fc2”,
“model.encoder.layers.3.self_attn.k_proj”,
“model.encoder.layers.3.self_attn.v_proj”,
“model.encoder.layers.3.self_attn.q_proj”,
“model.encoder.layers.3.self_attn.out_proj”,
“model.encoder.layers.3.fc1”,
“model.encoder.layers.3.fc2”,
“model.encoder.layers.4.self_attn.k_proj”,
“model.encoder.layers.4.self_attn.v_proj”,
“model.encoder.layers.4.self_attn.q_proj”,
“model.encoder.layers.4.self_attn.out_proj”,
“model.encoder.layers.4.fc1”,
“model.encoder.layers.4.fc2”,
“model.encoder.layers.5.self_attn.k_proj”,
“model.encoder.layers.5.self_attn.v_proj”,
“model.encoder.layers.5.self_attn.q_proj”,
“model.encoder.layers.5.self_attn.out_proj”,
“model.encoder.layers.5.fc1”,
“model.encoder.layers.5.fc2”,
“model.encoder.layers.6.self_attn.k_proj”,
“model.encoder.layers.6.self_attn.v_proj”,
“model.encoder.layers.6.self_attn.q_proj”,
“model.encoder.layers.6.self_attn.out_proj”,
“model.encoder.layers.6.fc1”,
“model.encoder.layers.6.fc2”,
“model.encoder.layers.7.self_attn.k_proj”,
“model.encoder.layers.7.self_attn.v_proj”,
“model.encoder.layers.7.self_attn.q_proj”,
“model.encoder.layers.7.self_attn.out_proj”,
“model.encoder.layers.7.fc1”,
“model.encoder.layers.7.fc2”,
“model.encoder.layers.8.self_attn.k_proj”,
“model.encoder.layers.8.self_attn.v_proj”,
“model.encoder.layers.8.self_attn.q_proj”,
“model.encoder.layers.8.self_attn.out_proj”,
“model.encoder.layers.8.fc1”,
“model.encoder.layers.8.fc2”,
“model.encoder.layers.9.self_attn.k_proj”,
“model.encoder.layers.9.self_attn.v_proj”,
“model.encoder.layers.9.self_attn.q_proj”,
“model.encoder.layers.9.self_attn.out_proj”,
“model.encoder.layers.9.fc1”,
“model.encoder.layers.9.fc2”,
“model.encoder.layers.10.self_attn.k_proj”,
“model.encoder.layers.10.self_attn.v_proj”,
“model.encoder.layers.10.self_attn.q_proj”,
“model.encoder.layers.10.self_attn.out_proj”,
“model.encoder.layers.10.fc1”,
“model.encoder.layers.10.fc2”,
“model.encoder.layers.11.self_attn.k_proj”,
“model.encoder.layers.11.self_attn.v_proj”,
“model.encoder.layers.11.self_attn.q_proj”,
“model.encoder.layers.11.self_attn.out_proj”,
“model.encoder.layers.11.fc1”,
“model.encoder.layers.11.fc2”,

# Decoder layers
"model.decoder.layers.0.self_attn.k_proj",
"model.decoder.layers.0.self_attn.v_proj",
"model.decoder.layers.0.self_attn.q_proj",
"model.decoder.layers.0.self_attn.out_proj",
"model.decoder.layers.0.encoder_attn.k_proj",
"model.decoder.layers.0.encoder_attn.v_proj",
"model.decoder.layers.0.encoder_attn.q_proj",
"model.decoder.layers.0.encoder_attn.out_proj",
"model.decoder.layers.0.fc1",
"model.decoder.layers.0.fc2",
"model.decoder.layers.1.self_attn.k_proj",
"model.decoder.layers.1.self_attn.v_proj",
"model.decoder.layers.1.self_attn.q_proj",
"model.decoder.layers.1.self_attn.out_proj",
"model.decoder.layers.1.encoder_attn.k_proj",
"model.decoder.layers.1.encoder_attn.v_proj",
"model.decoder.layers.1.encoder_attn.q_proj",
"model.decoder.layers.1.encoder_attn.out_proj",
"model.decoder.layers.1.fc1",
"model.decoder.layers.1.fc2",
"model.decoder.layers.2.self_attn.k_proj",
"model.decoder.layers.2.self_attn.v_proj",
"model.decoder.layers.2.self_attn.q_proj",
"model.decoder.layers.2.self_attn.out_proj",
"model.decoder.layers.2.encoder_attn.k_proj",
"model.decoder.layers.2.encoder_attn.v_proj",
"model.decoder.layers.2.encoder_attn.q_proj",
"model.decoder.layers.2.encoder_attn.out_proj",
"model.decoder.layers.2.fc1",
"model.decoder.layers.2.fc2",
"model.decoder.layers.3.self_attn.k_proj",
"model.decoder.layers.3.self_attn.v_proj",
"model.decoder.layers.3.self_attn.q_proj",
"model.decoder.layers.3.self_attn.out_proj",
"model.decoder.layers.3.encoder_attn.k_proj",
"model.decoder.layers.3.encoder_attn.v_proj",
"model.decoder.layers.3.encoder_attn.q_proj",
"model.decoder.layers.3.encoder_attn.out_proj",
"model.decoder.layers.3.fc1",
"model.decoder.layers.3.fc2",
"model.decoder.layers.4.self_attn.k_proj",
"model.decoder.layers.4.self_attn.v_proj",
"model.decoder.layers.4.self_attn.q_proj",
"model.decoder.layers.4.self_attn.out_proj",
"model.decoder.layers.4.encoder_attn.k_proj",
"model.decoder.layers.4.encoder_attn.v_proj",
"model.decoder.layers.4.encoder_attn.q_proj",
"model.decoder.layers.4.encoder_attn.out_proj",
"model.decoder.layers.4.fc1",
"model.decoder.layers.4.fc2",
"model.decoder.layers.5.self_attn.k_proj",
"model.decoder.layers.5.self_attn.v_proj",
"model.decoder.layers.5.self_attn.q_proj",
"model.decoder.layers.5.self_attn.out_proj",
"model.decoder.layers.5.encoder_attn.k_proj",
"model.decoder.layers.5.encoder_attn.v_proj",
"model.decoder.layers.5.encoder_attn.q_proj",
"model.decoder.layers.5.encoder_attn.out_proj",
"model.decoder.layers.5.fc1",
"model.decoder.layers.5.fc2",
"model.decoder.layers.6.self_attn.k_proj",
"model.decoder.layers.6.self_attn.v_proj",
"model.decoder.layers.6.self_attn.q_proj",
"model.decoder.layers.6.self_attn.out_proj",
"model.decoder.layers.6.encoder_attn.k_proj",
"model.decoder.layers.6.encoder_attn.v_proj",
"model.decoder.layers.6.encoder_attn.q_proj",
"model.decoder.layers.6.encoder_attn.out_proj",
"model.decoder.layers.6.fc1",
"model.decoder.layers.6.fc2",
"model.decoder.layers.7.self_attn.k_proj",
"model.decoder.layers.7.self_attn.v_proj",
"model.decoder.layers.7.self_attn.q_proj",
"model.decoder.layers.7.self_attn.out_proj",
"model.decoder.layers.7.encoder_attn.k_proj",
"model.decoder.layers.7.encoder_attn.v_proj",
"model.decoder.layers.7.encoder_attn.q_proj",
"model.decoder.layers.7.encoder_attn.out_proj",
"model.decoder.layers.7.fc1",
"model.decoder.layers.7.fc2",
"model.decoder.layers.8.self_attn.k_proj",
"model.decoder.layers.8.self_attn.v_proj",
"model.decoder.layers.8.self_attn.q_proj",
"model.decoder.layers.8.self_attn.out_proj",
"model.decoder.layers.8.encoder_attn.k_proj",
"model.decoder.layers.8.encoder_attn.v_proj",
"model.decoder.layers.8.encoder_attn.q_proj",
"model.decoder.layers.8.encoder_attn.out_proj",
"model.decoder.layers.8.fc1",
"model.decoder.layers.8.fc2",
"model.decoder.layers.9.self_attn.k_proj",
"model.decoder.layers.9.self_attn.v_proj",
"model.decoder.layers.9.self_attn.q_proj",
"model.decoder.layers.9.self_attn.out_proj",
"model.decoder.layers.9.encoder_attn.k_proj",
"model.decoder.layers.9.encoder_attn.v_proj",
"model.decoder.layers.9.encoder_attn.q_proj",
"model.decoder.layers.9.encoder_attn.out_proj",
"model.decoder.layers.9.fc1",
"model.decoder.layers.9.fc2",
"model.decoder.layers.10.self_attn.k_proj",
"model.decoder.layers.10.self_attn.v_proj",
"model.decoder.layers.10.self_attn.q_proj",
"model.decoder.layers.10.self_attn.out_proj",
"model.decoder.layers.10.encoder_attn.k_proj",
"model.decoder.layers.10.encoder_attn.v_proj",
"model.decoder.layers.10.encoder_attn.q_proj",
"model.decoder.layers.10.encoder_attn.out_proj",
"model.decoder.layers.10.fc1",
"model.decoder.layers.10.fc2",
"model.decoder.layers.11.self_attn.k_proj",
"model.decoder.layers.11.self_attn.v_proj",
"model.decoder.layers.11.self_attn.q_proj",
"model.decoder.layers.11.self_attn.out_proj",
"model.decoder.layers.11.encoder_attn.k_proj",
"model.decoder.layers.11.encoder_attn.v_proj",
"model.decoder.layers.11.encoder_attn.q_proj",
"model.decoder.layers.11.encoder_attn.out_proj",
"model.decoder.layers.11.fc1",
"model.decoder.layers.11.fc2",

]

Configure LoRA

config = LoraConfig(
task_type=TaskType.SEQ_2_SEQ_LM,
r=8,
lora_alpha=16,
target_modules=target_modules,
lora_dropout=0.1,
)

for name, module in model.named_modules():

print(name)

Apply LoRA to the model

lora_model = get_peft_model(model, config)

Set model to training mode

lora_model.train()

class WhisperDataset(Dataset):
def init(self, csv_file, audio_dir, processor):
self.data = pd.read_csv(csv_file)
self.audio_dir = audio_dir
self.processor = processor

def __len__(self):
    return len(self.data)

def __getitem__(self, idx):
    if torch.is_tensor(idx):
        idx = idx.tolist()

    audio_path = f"{self.audio_dir}/{self.data.iloc[idx, 0]}"
    transcription = self.data.iloc[idx, 1]

    # Load audio file
    speech_array, sampling_rate = torchaudio.load(audio_path)

    # Process the audio and transcription
    input_features = self.processor.feature_extractor(speech_array.squeeze().numpy(), sampling_rate=sampling_rate, return_tensors="pt").input_features
    labels = self.processor.tokenizer(transcription, return_tensors="pt").input_ids
    # print(f"Processing index: {idx}")
    # print(f"Transcription: {transcription}")
    # print(f"Input features shape: {input_features.shape}")
    # print(f"Labels (input_ids): {labels}")

    return {
        "input_features": input_features.squeeze(),
        "labels": labels.squeeze()
    }

Define a custom collate function to pad sequences

def collate_fn(batch):
input_features = [item[‘input_features’] for item in batch]
labels = [item[‘labels’] for item in batch]

# Pad sequences
input_features_padded = torch.nn.utils.rnn.pad_sequence(input_features, batch_first=True, padding_value=0)
labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_token_id)

return {
    'input_features': input_features_padded,
    'labels': labels_padded
}

Example usage

csv_file = “./dup.csv” # Path to your CSV file
audio_dir = “.” # Path to the directory containing your audio files
processor = WhisperProcessor.from_pretrained(“openai/whisper-small”, language=“Tamil”, task=“transcribe”)

dataset = WhisperDataset(csv_file, audio_dir, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

Define the optimizer

optimizer = torch.optim.Adam(lora_model.parameters(), lr=1e-4)

Fine-tuning loop

num_epochs = 3 # Define the number of epochs
for epoch in range(num_epochs):
for batch in dataloader:
# Forward pass
input_features = batch[‘input_features’].to(lora_model.device)
labels = batch[‘labels’].to(lora_model.device)

    print(f"Batch input features shape: {input_features.shape}")
    print(f"Batch labels shape: {labels.shape}")
    print(f"Batch labels: {labels}")
    # print(input_features.shape)
    outputs = lora_model(input_features=input_features, labels=labels)
    loss = outputs.loss

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

Save the fine-tuned model

lora_model.save_pretrained(“/content/fine_tuned_whisper_with_lora”)

TypeError: WhisperForConditionalGeneration.forward() got an unexpected keyword argument ‘input_ids’

Above error is shown for the code

Same here, it seems that Peft wrapper always calls the model with the input_ids and input_embeds kwargs, which is unecpected by the Whisper model

It looks like get_peft_model is only used to get the more appropriate model based on argument task_typefrom LoraConfig.

Since none of the wrappers work for Whisper, you have to create your own implementation.

So instead of lora_model = get_peft_model(model, config), you can use lora_model = WhisperTuner(model, config).

Where WhisperTuner would be something like this:

class WhisperTuner(PeftModel):
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(
            self,
            attention_mask=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            **kwargs,
    ):
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

    def generate(self, **kwargs):
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not peft_config.is_prompt_learning:
                with self._enable_peft_forward_hooks(**kwargs):
                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                    outputs = self.base_model.generate(**kwargs)
        except:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            raise
        else:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )
            return outputs

    def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if peft_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids
        if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            past_key_values = self.get_prompt(batch_size)
            model_kwargs["past_key_values"] = past_key_values

        return model_kwargs

(I just copied an implementation of PeftModel and deleted all paths and entries that are not needed for whisper, there’s probably room for improvement in this code)