How to use PEFT's LoRA with Optimum's BetterTransformer?

The particular model I’m currently using is Wav2Vec 2.0. It seems like LoRA + Flash Attention (via BetterTransformer) is not possible.

import torch
from transformers import AutoModel
from peft import get_peft_model, LoraConfig
from optimum.bettertransformer import BetterTransformer

model_id = "facebook/wav2vec2-xls-r-300m"

model = AutoModel.from_pretrained(model_id)
peft_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["k_proj", "q_proj"],
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

x = torch.randn(4, 16_000 * 5)  # batch of 4 random 5-second waveforms at 16 kHz

model(x)

The above runs ok.
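For reference, this is roughly what the LoRA patching looks like (a minimal inspection sketch; the module paths assume the standard transformers Wav2Vec2 layout):

# After get_peft_model, the targeted projections are no longer plain nn.Linear:
attn = model.base_model.model.encoder.layers[0].attention
print(type(attn.k_proj))  # PEFT's LoRA wrapper around the original nn.Linear
print(type(attn.q_proj))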

The following does not work.

model = AutoModel.from_pretrained(model_id)  # fresh base model, not the one above
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model = BetterTransformer.transform(model)

x = torch.randn(4, 16_000 * 5)

model(x)

Error

Traceback (most recent call last):
  File "/home/ubuntu/code/voice-anti-spoofing/debug.py", line 29, in <module>
    model(x)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/peft/peft_model.py", line 442, in forward
    return self.get_base_model()(*args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1568, in forward
    encoder_outputs = self.encoder(
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 908, in forward
    layer_outputs = layer(
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 1417, in forward
    raise NotImplementedError(
NotImplementedError: Training and Autocast are not implemented for BetterTransformer + Wav2Vec2. Please open an issue.

This is probably because Optimum replaces the Wav2Vec2 encoder layers with different modules, so the new modules are not compatible with the patching that LoRA applies (https://github.com/huggingface/optimum/blob/05d20df3e6602e26d01cf3994a108de5b097a719/optimum/bettertransformer/models/encoder_models.py#L1315). The NotImplementedError message itself seems unrelated: if I remove LoRA, the BetterTransformer model runs fine.
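To see the replacement, one can transform a fresh model without LoRA and inspect the encoder layers (a quick check, assuming the same module layout as above):

plain = AutoModel.from_pretrained(model_id)
plain = BetterTransformer.transform(plain)
print(type(plain.encoder.layers[0]))
# -> Optimum's fused encoder layer (Wav2Vec2EncoderLayerBetterTransformer in the
#    linked source), which packs the attention projections into single weights,
#    so the original k_proj/q_proj nn.Linear modules (and any LoRA wrappers) are gone.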

Trying to create a BetterTransformer model before applying LoRA does not work either, since Optimum patches the model and the original nn.Linear modules (k_proj, q_proj) are no longer available for LoRA to target.
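For completeness, this is the reverse-order attempt (a sketch; the exact exception message may vary by PEFT version):

model = AutoModel.from_pretrained(model_id)
model = BetterTransformer.transform(model)
# Fails: the fused layers no longer contain submodules named k_proj/q_proj,
# so PEFT cannot resolve target_modules.
model = get_peft_model(model, peft_config)  # e.g. ValueError: target modules not found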

I wanted to create an issue on GitHub, but this does not seem to fit any repo in particular, so I posted here instead. Please advise what I can do to make LoRA + Flash Attention possible, and if it is currently not possible, point me to where I should open an issue.

Thank you
