How to train Wav2Vec2 with LoRA?

Hi! I want to fine-tune a Wav2Vec2 model using the Trainer and LoRA. However, I am getting an error in the data collator when running the code. The code is given below:


class W2v2Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df
        self.paths = df['id'].values
        self.sentences = df['normalized'].values
        # self.resampler = tat.Resample(32000, SR)

    def __getitem__(self, idx):
        apath = f'/kaggle/input/bengaliai-speech/train_mp3s/{self.paths[idx]}.mp3'
        # waveform, sample_rate = torchaudio.load(apath, format="mp3")
        waveform, sample_rate = librosa.load(apath, sr=16000)
        # waveform = self.resampler(waveform)
        batch = dict()
        # extract the (unpadded) input features; SR is a global set to 16000
        y = processor(waveform.reshape(-1), sampling_rate=SR).input_values[0]
        batch["input_values"] = y
        # tokenize the target transcription into label ids
        with processor.as_target_processor():
            batch["labels"] = processor(self.sentences[idx]).input_ids

        return batch

    def __len__(self):
        return len(self.df)

train_dataset = W2v2Dataset(train)
valid_dataset = W2v2Dataset(valid)

processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
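
As a quick sanity check, a single dataset item should come back as a dict with raw input values and label ids. This is a hypothetical snippet: it assumes train is a DataFrame with 'id' and 'normalized' columns and that SR = 16000 is defined globally, as the dataset class implies.

# hypothetical sanity check of a single dataset item
sample = train_dataset[0]
print(len(sample["input_values"]))  # number of raw audio samples (floats)
print(sample["labels"][:10])        # first few label token ids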



@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`):
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
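
Before training, the collator can be smoke-tested directly on a couple of items (a sketch under the same assumptions as the dataset check above):

# hypothetical check: pad two items into a batch of tensors
test_batch = data_collator([train_dataset[0], train_dataset[1]])
print(test_batch["input_values"].shape)  # (2, longest_audio_length)
print(test_batch["labels"].shape)        # (2, longest_label_length); pads are -100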


wer_metric = load_metric("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
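
To illustrate what compute_metrics reports, the WER metric can be run on toy strings; one substituted word out of two reference words gives 0.5:

# toy example with the same metric object used above
print(wer_metric.compute(predictions=["hello world"], references=["hello word"]))  # 0.5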


model = Wav2Vec2ForCTC.from_pretrained(
    MODEL_PATH,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    # gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ctc_zero_infinity=True,
    diversity_loss_weight=100 
)


model.freeze_feature_extractor()

## LoRA Config
peft_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["k_proj", "q_proj"],
)

model = get_peft_model(model, peft_config)
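
At this point it is worth confirming that only the adapter weights were left trainable; print_trainable_parameters is part of the PEFT model API:

# prints roughly: trainable params: ... || all params: ... || trainable%: ...
model.print_trainable_parameters()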

model.to(device)


training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    group_by_length=False,
    lr_scheduler_type='cosine',
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    optim="adamw_bnb_8bit",
    # dataloader_pin_memory=True,
    # dataloader_num_workers=4,
    save_strategy="steps",
    num_train_epochs=1,
    # max_steps=15000, # you can change to "num_train_epochs"
    fp16=True,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=1000,
    learning_rate=5e-5,
    warmup_steps=600,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    prediction_loss_only=False,
    auto_find_batch_size=True,
    report_to="none"
)


trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)


trainer.train()


╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1                                                                                    │
│                                                                                                  │
│ ❱ 1 trainer.train()                                                                              │
│   2                                                                                              │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1645 in train                    │
│                                                                                                  │
│   1642 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1643 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1644 │   │   )                                                                                 │
│ ❱ 1645 │   │   return inner_training_loop(                                                       │
│   1646 │   │   │   args=args,                                                                    │
│   1647 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1648 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/accelerate/utils/memory.py:132 in decorator              │
│                                                                                                  │
│   129 │   │   │   if batch_size == 0:                                                            │
│   130 │   │   │   │   raise RuntimeError("No executable batch size found, reached zero.")        │
│   131 │   │   │   try:                                                                           │
│ ❱ 132 │   │   │   │   return function(batch_size, *args, **kwargs)                               │
│   133 │   │   │   except Exception as e:                                                         │
│   134 │   │   │   │   if should_reduce_batch_size(e):                                            │
│   135 │   │   │   │   │   gc.collect()                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1916 in _inner_training_loop     │
│                                                                                                  │
│   1913 │   │   │   │   rng_to_sync = True                                                        │
│   1914 │   │   │                                                                                 │
│   1915 │   │   │   step = -1                                                                     │
│ ❱ 1916 │   │   │   for step, inputs in enumerate(epoch_iterator):                                │
│   1917 │   │   │   │   total_batched_samples += 1                                                │
│   1918 │   │   │   │   if rng_to_sync:                                                           │
│   1919 │   │   │   │   │   self._load_rng_state(resume_from_checkpoint)                          │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:634 in __next__           │
│                                                                                                  │
│    631 │   │   │   if self._sampler_iter is None:                                                │
│    632 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    633 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  634 │   │   │   data = self._next_data()                                                      │
│    635 │   │   │   self._num_yielded += 1                                                        │
│    636 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
│    637 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:678 in _next_data         │
│                                                                                                  │
│    675 │                                                                                         │
│    676 │   def _next_data(self):                                                                 │
│    677 │   │   index = self._next_index()  # may raise StopIteration                             │
│ ❱  678 │   │   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration              │
│    679 │   │   if self._pin_memory:                                                              │
│    680 │   │   │   data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)            │
│    681 │   │   return data                                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54 in fetch             │
│                                                                                                  │
│   51 │   │   │   │   data = [self.dataset[idx] for idx in possibly_batched_index]                │
│   52 │   │   else:                                                                               │
│   53 │   │   │   data = self.dataset[possibly_batched_index]                                     │
│ ❱ 54 │   │   return self.collate_fn(data)                                                        │
│   55                                                                                             │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/transformers/trainer_utils.py:706 in __call__            │
│                                                                                                  │
│   703 │                                                                                          │
│   704 │   def __call__(self, features: List[dict]):                                              │
│   705 │   │   features = [self._remove_columns(feature) for feature in features]                 │
│ ❱ 706 │   │   return self.data_collator(features)                                                │
│   707                                                                                            │
│                                                                                                  │
│ in __call__:37                                                                                   │
│                                                                                                  │
│   34 │   def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dic    │
│   35 │   │   # split inputs and labels since they have to be of different lenghts and need       │
│   36 │   │   # different padding methods                                                         │
│ ❱ 37 │   │   input_features = [{"input_values": feature["input_values"]} for feature in featu    │
│   38 │   │   label_features = [{"input_ids": feature["labels"]} for feature in features]         │
│   39 │   │                                                                                       │
│   40 │   │   batch = self.processor.pad(                                                         │
│                                                                                                  │
│ in <listcomp>:37                                                                                 │
│                                                                                                  │
│   34 │   def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dic    │
│   35 │   │   # split inputs and labels since they have to be of different lenghts and need       │
│   36 │   │   # different padding methods                                                         │
│ ❱ 37 │   │   input_features = [{"input_values": feature["input_values"]} for feature in featu    │
│   38 │   │   label_features = [{"input_ids": feature["labels"]} for feature in features]         │
│   39 │   │                                                                                       │
│   40 │   │   batch = self.processor.pad(                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'input_values'



Why am I getting this error? What am I doing wrong?
If I run the same code without LoRA (i.e., without LoraConfig() and get_peft_model()), it runs fine. Kindly point out the mistake.
Thanks!

You will need to subclass PeftModelForSequenceClassification. You can try something like the snippet below:

from peft import PeftModelForSequenceClassification, PeftConfig
import torch
from typing import Optional

class PeftModelForAudioClassification(PeftModelForSequenceClassification):
    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        super().__init__(model, peft_config, adapter_name)
        if self.modules_to_save is None:
            self.modules_to_save = {"classifier", "score"}
        else:
            self.modules_to_save.update({"classifier", "score"})
        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break
    
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            # LoRA is not a prompt-learning method, so this is the branch that
            # runs; pass arguments by keyword so their order cannot get mixed up
            return self.base_model(
                input_values=input_values,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
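
Then, instead of calling get_peft_model, wrap the base model with the subclass directly. This is a hypothetical usage sketch; note that a Wav2Vec2ForCTC model's head is named lm_head, so for CTC you would add "lm_head" to modules_to_save as well:

# hypothetical usage of the subclass above in place of get_peft_model
model = PeftModelForAudioClassification(model, peft_config)
model.print_trainable_parameters()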