Improving performance of Wav2Vec2 fine-tuning with a word-piece vocabulary

Hello,

I’m fine-tuning XLSR-Wav2Vec2 on 200+ hours of speech in a language that was not in the original pretraining data.

Training progresses nicely, but when it reaches about 40% WER it starts to overfit: WER stops improving much, and the training loss keeps decreasing while the eval loss goes up.

I’ve tried increasing some of the SpecAugment parameters, but it only helped a bit.

I’ve noticed that with the SpeechBrain implementation I get somewhat better results (at the expense of training stability), and I was wondering whether that is due to the larger vocabulary they use there. Has anyone tried a tokenizer whose vocabulary contains subwords and words in addition to characters? I couldn’t find any experiment that does this with the Hugging Face Transformers Wav2Vec2 model.

The wav2vec 2.0 paper says:

We expect performance gains by switching to a seq2seq architecture and a word piece vocabulary.
https://arxiv.org/pdf/2006.11477.pdf

Any suggestions on how to do that with Hugging Face Transformers?

P.S. My dataset is noisy and not very clean.

Any help or suggestions would be much appreciated.

Samuel

Not sure how I’d switch to a seq2seq architecture, but for word pieces, I think you just need to change the vocab passed to Wav2Vec2CTCTokenizer. Instead of the individual alphabet characters used as the vocab in the XLSR example, you’d run a WordPiece/BPE algorithm on text in your language and pass the resulting vocabulary in.
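
A minimal sketch of the BPE step, in plain Python so it runs without extra packages. It learns merges from a toy list of transcripts and builds a token-to-id dict shaped like the vocab.json from the XLSR fine-tuning example ([PAD], [UNK], and | as the word delimiter). The function names, the toy corpus, and the merge count are hypothetical; for real data you would more likely use a library such as SentencePiece or Hugging Face tokenizers. One caveat, hedged: as far as I can tell, the stock Wav2Vec2CTCTokenizer still splits input text into single characters, so the multi-character entries only get used if the segmentation step is replaced as well, as the next post suggests.

```python
import collections
import json


def learn_bpe(corpus, num_merges):
    """Classic BPE merge learning over the words of a list of sentences."""
    # Each word starts as a tuple of characters; count word frequencies.
    # No end-of-word marker is used, since "|" already marks word boundaries.
    word_freqs = collections.Counter()
    for sentence in corpus:
        for word in sentence.split():
            word_freqs[tuple(word)] += 1

    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pair_counts = collections.Counter()
        for symbols, freq in word_freqs.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break  # every word is already a single symbol
        best = max(pair_counts, key=pair_counts.get)  # ties: first seen wins
        merges.append(best)

        # Rewrite every word, replacing occurrences of the best pair.
        new_freqs = collections.Counter()
        for symbols, freq in word_freqs.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_freqs[tuple(merged)] += freq
        word_freqs = new_freqs
    return merges, word_freqs


def build_vocab(corpus, merges):
    """Token -> id dict: special tokens, then characters, then merged units."""
    vocab = {"[PAD]": 0, "[UNK]": 1, "|": 2}
    chars = sorted({c for sentence in corpus for c in sentence if c != " "})
    for token in chars + ["".join(pair) for pair in merges]:
        vocab.setdefault(token, len(vocab))
    return vocab


if __name__ == "__main__":
    corpus = ["low lower lowest", "low low"]  # hypothetical toy transcripts
    merges, _ = learn_bpe(corpus, num_merges=2)
    # This JSON is what you would save as vocab.json for the tokenizer.
    print(json.dumps(build_vocab(corpus, merges), ensure_ascii=False))
```

The number of merges controls the final vocabulary size, so it is the knob to tune between a character-level vocab (zero merges) and a large subword vocab.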

Thanks for the answer!
Any code examples or ideas on how to use a word-piece tokenizer easily? As I understand it, I’d basically need to override most of the functions in transformers/models/wav2vec2/tokenization_wav2vec2.py
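
One possible workaround, offered as an assumption rather than an official Transformers API: segment the transcripts yourself and feed the resulting ids to the model as labels (padded positions set to -100, as in the XLSR example's data collator), instead of overriding the tokenizer class. The sketch below uses greedy longest-match-first segmentation, the same rule BERT's WordPiece applies, but over a flat vocabulary without the ## continuation prefixes. The vocabulary and function names are hypothetical.

```python
def segment_word(word, vocab, unk="[UNK]"):
    """Split one word into the longest vocab pieces, scanning left to right."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        # Shrink the candidate substring until it is in the vocabulary.
        while end > start and word[start:end] not in vocab:
            end -= 1
        if end == start:
            # Not even the single character is known: emit UNK and move on.
            # (BERT's WordPiece would map the whole word to UNK instead.)
            pieces.append(unk)
            start += 1
        else:
            pieces.append(word[start:end])
            start = end
    return pieces


def encode(text, vocab, delimiter="|"):
    """Turn a transcript into CTC label ids, with a delimiter between words."""
    tokens = []
    for i, word in enumerate(text.split()):
        if i > 0:
            tokens.append(delimiter)
        tokens.extend(segment_word(word, vocab))
    return [vocab[token] for token in tokens]


# Hypothetical toy vocabulary mixing characters and subwords.
VOCAB = {"[PAD]": 0, "[UNK]": 1, "|": 2, "l": 3, "o": 4, "w": 5,
         "e": 6, "r": 7, "low": 8, "er": 9}
```

For example, encode("lower low", VOCAB) yields the pieces low, er, |, low. The ids must stay below the model's vocab_size, and the model's output layer has to be sized to len(VOCAB).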


You can look into SentencePiece.
Hope that helps!

This can be done by using BertTokenizer and setting vocab_size to 30522 (the size of the bert-base-uncased vocabulary). Keep in mind, though, that you don’t want to reuse the existing lm_head weights from the Wav2Vec2ForCTC checkpoint, since they map to the old character vocabulary. I did this with the TensorFlow version, but I don’t think the PyTorch CTC loss has a vocabulary limit either.
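
On the decoding side, CTC output over a BERT vocabulary can still be decoded greedily: take the argmax id per frame, collapse repeats, drop the blank, then glue WordPiece continuations (tokens starting with ##) onto the preceding piece. A minimal pure-Python sketch; the id-to-token table is a hypothetical toy, and the blank is assumed to be the [PAD] id, since Wav2Vec2ForCTC uses pad_token_id as the CTC blank. In PyTorch, the frame ids would come from something like logits.argmax(dim=-1).

```python
def ctc_greedy_decode(frame_ids, id_to_token, blank_id):
    """Greedy CTC decoding followed by WordPiece detokenization."""
    pieces, prev = [], None
    for idx in frame_ids:
        # A new token starts when the id changes and is not the blank.
        if idx != prev and idx != blank_id:
            pieces.append(id_to_token[idx])
        prev = idx
    # Merge "##" continuation pieces into the previous word.
    words = []
    for piece in pieces:
        if piece.startswith("##") and words:
            words[-1] += piece[2:]
        else:
            words.append(piece)
    return " ".join(words)


# Hypothetical toy table; in BERT's vocab, [PAD] is id 0.
ID_TO_TOKEN = {0: "[PAD]", 1: "hello", 2: "wor", 3: "##ld"}
```

A blank between two identical ids is what lets CTC emit the same token twice in a row, e.g. frames [1, 0, 1] decode to "hello hello".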