Pretrain facebook/wav2vec2-base

Hello everyone!

I want to additionally pretrain facebook/wav2vec2-base on my custom data.

Currently I’m trying to use this script.

I’m running it with the following arguments (I edited the code a bit to use my personal datasets, so I’m omitting the dataset-related args):

accelerate launch \
    --audio_column_name="path" \
    --model_name_or_path="facebook/wav2vec2-base" \
    --output_dir="./wav2vec2-pretrained-demo" \
    --max_train_steps="200000" \
    --num_warmup_steps="32000" \
    --gradient_accumulation_steps="4" \
    --learning_rate="0.001" \
    --weight_decay="0.01" \
    --max_duration_in_seconds="20.0" \
    --min_duration_in_seconds="2.0" \
    --logging_steps="1" \
    --saving_steps="10000" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --adam_epsilon="1e-06" \
    --gradient_checkpointing

And I’m getting the following error:

ValueError: PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'``

Could you please tell me where I can find code to continue pretraining from the facebook/wav2vec2-base checkpoint?

CC @patrickvonplaten @anton-l

Hey @ifedorov,

Yeah, sadly we currently don’t support pretraining with the do_stable_layer_norm=False config. Could you instead pretrain the model using this config: patrickvonplaten/wav2vec2-base-v2 · Hugging Face
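For context, the two flags named in the error live on Wav2Vec2Config. A minimal sketch (no checkpoint download; it only uses the public config class) showing that the default, base-style config fails the script’s check while a layer-norm variant passes:

```python
from transformers import Wav2Vec2Config

# The default Wav2Vec2Config mirrors the "base" architecture, which the
# pretraining script rejects:
base_like = Wav2Vec2Config()
print(base_like.do_stable_layer_norm, base_like.feat_extract_norm)  # False group

# Flipping the two flags gives a pretraining-compatible variant
# (this is what the wav2vec2-base-v2 config does):
pretrain_ok = Wav2Vec2Config(
    do_stable_layer_norm=True,   # apply layer norm before attention
    feat_extract_norm="layer",   # layer norm in the conv feature extractor
)
print(pretrain_ok.do_stable_layer_norm, pretrain_ok.feat_extract_norm)  # True layer
```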

Hi @patrickvonplaten — I have run into the same issue as OP, and just wanted to clarify something.

I see that in patrickvonplaten/wav2vec2-base-v2 at main there is no pytorch_model.bin (unlike facebook/wav2vec2-base at main).

When you run additional pre-training using:

accelerate launch \
	# ...
	# ...

Does it automatically fetch the weights from facebook/wav2vec2-base (to do additional pre-training) or does it randomly initialize a new model with the config specified in patrickvonplaten/wav2vec2-base-v2?


Edit: ah, nevermind — found the answer in the script itself:

# initialize random model (no checkpoint weights are loaded)
model = Wav2Vec2ForPreTraining(config)
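For anyone else wondering: a small sketch confirming that Wav2Vec2ForPreTraining(config) gives fresh random weights rather than fetching a checkpoint. The tiny config dimensions here are arbitrary, chosen only so the example runs quickly:

```python
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining

# Deliberately tiny config (hypothetical sizes, just for a fast demo)
config = Wav2Vec2Config(
    do_stable_layer_norm=True,
    feat_extract_norm="layer",
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)

# Two independently constructed models get different random weights,
# confirming no checkpoint is fetched behind the scenes.
a = Wav2Vec2ForPreTraining(config)
b = Wav2Vec2ForPreTraining(config)
same = torch.equal(
    a.wav2vec2.encoder.layers[0].attention.q_proj.weight,
    b.wav2vec2.encoder.layers[0].attention.q_proj.weight,
)
print(same)  # False: weights are randomly initialized, not loaded

# To continue pretraining from existing weights, they must be loaded
# explicitly (downloads the checkpoint, so left commented out here):
# model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
```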