The particular model I’m currently using is Wav2Vec 2.0. It seems like LoRA + Flash attention is not possible.
import torch
from transformers import AutoModel
from peft import get_peft_model, LoraConfig
from optimum.bettertransformer import BetterTransformer
model_id = "facebook/wav2vec2-xls-r-300m"
model = AutoModel.from_pretrained(model_id)
peft_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["k_proj", "q_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
x = torch.randn(4, 16_000 * 5)  # batch of 4 five-second waveforms at 16 kHz
model(x)
The above runs fine.
The following does not work:
model = AutoModel.from_pretrained(model_id)  # fresh base model
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model = BetterTransformer.transform(model)
x = torch.randn(4, 16_000 * 5)
model(x)
Error:
Traceback (most recent call last):
File "/home/ubuntu/code/voice-anti-spoofing/debug.py", line 29, in <module>
model(x)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/peft/peft_model.py", line 442, in forward
return self.get_base_model()(*args, **kwargs)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1568, in forward
encoder_outputs = self.encoder(
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 908, in forward
layer_outputs = layer(
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/mambaforge/envs/dev/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 1417, in forward
raise NotImplementedError(
NotImplementedError: Training and Autocast are not implemented for BetterTransformer + Wav2Vec2. Please open an issue.
This is probably because Optimum replaces the Wav2Vec2 encoder layers with fused modules (https://github.com/huggingface/optimum/blob/05d20df3e6602e26d01cf3994a108de5b097a719/optimum/bettertransformer/models/encoder_models.py#L1315), so the new modules are not compatible with the in-place patches applied by LoRA. The error message itself is misleading: the real problem is not training or autocast, since the same code runs fine if I remove LoRA.
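To illustrate the mechanism I suspect, here is a minimal pure-Python sketch (no transformers/peft/optimum needed; all class names are made up for illustration): LoRA wraps specific submodules in place, so when an optimization pass replaces the whole layer with a fused version, the wrapped modules are simply discarded.

```python
class Linear:
    """Stand-in for nn.Linear."""
    pass

class LoraLinear(Linear):
    """Stand-in for peft's LoRA wrapper around an existing Linear."""
    def __init__(self, base):
        self.base = base

class AttentionLayer:
    """Stand-in for a Wav2Vec2 encoder layer with its projection modules."""
    def __init__(self):
        self.q_proj = Linear()
        self.k_proj = Linear()

class FusedAttentionLayer:
    """Stand-in for BetterTransformer's fused layer: the weights would be
    packed into one tensor, and the original submodules are not kept."""
    def __init__(self, original):
        pass  # nothing from `original` survives as a named submodule

layer = AttentionLayer()

# 1) apply "LoRA": wrap the projections in place
layer.q_proj = LoraLinear(layer.q_proj)
layer.k_proj = LoraLinear(layer.k_proj)

# 2) apply the "BetterTransformer" transform: the whole layer is replaced
layer = FusedAttentionLayer(layer)

# The LoRA wrappers are gone; there is no q_proj/k_proj left at all
assert not hasattr(layer, "q_proj")
assert not hasattr(layer, "k_proj")
```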
Creating the BetterTransformer model before applying LoRA also does not work: Optimum patches the model first, so the nn.Linear modules that LoRA targets (k_proj, q_proj) are no longer available.
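A small sketch of why this order fails too (the module names below are hypothetical, and the matching function only mimics how peft resolves target_modules by name): after the transform, the per-layer q_proj/k_proj names are gone, so there is nothing left to wrap.

```python
def find_targets(module_names, target_modules):
    """Mimic peft's name matching: keep modules whose qualified name
    is, or ends with, one of the requested target names."""
    return [
        name for name in module_names
        if any(name == t or name.endswith("." + t) for t in target_modules)
    ]

# Hypothetical module names before the BetterTransformer transform
before = [
    "encoder.layers.0.attention.q_proj",
    "encoder.layers.0.attention.k_proj",
    "encoder.layers.0.attention.v_proj",
]
# After the transform the fused layer exposes no projection submodules
after = ["encoder.layers.0"]

print(find_targets(before, ["k_proj", "q_proj"]))  # both projections found
print(find_targets(after, ["k_proj", "q_proj"]))   # nothing to wrap -> peft errors out
```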
I wanted to create an issue on GitHub, but this does not seem to fit any one repo in particular, so I am posting here instead. Please advise what I can do to make LoRA + Flash attention possible, and if it is currently not possible, point me to where I should open an issue.
Thank you