Wav2Vec2: How to correct for nan in training and validation loss

Hi,
I’m using Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base") to fine-tune on an English-language medical transcription dataset, which is about 6 GB. It is similar to the Timit_ASR dataset, except that the wav files are sampled at 48 kHz. I’m following the example shown in this notebook: Fine-Tune Wav2Vec2 for English ASR in Hugging Face with 🤗 Transformers
Thank you @patrickvonplaten for an excellent illustration.

My issue is that the training loss and validation loss show either nan or inf, and the WER does not decrease.

/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

 [5895/5895 24:03, Epoch 1/1]
Step	Training Loss	Validation Loss	Wer	Runtime	Samples Per Second
500	4.942800	nan	1.000000	20.818400	18.301000
1000	3.023500	nan	1.000000	20.785500	18.330000
1500	inf	3.477225	1.000000	20.757200	18.355000
2000	nan	3.492772	1.000000	20.922500	18.210000
2500	nan	nan	1.000000	20.817700	18.302000
3000	nan	nan	1.000000	20.788900	18.327000
3500	nan	nan	1.000000	20.850800	18.273000
4000	nan	nan	1.000000	20.928300	18.205000
4500	nan	3.214458	1.000000	20.900100	18.230000
5000	nan	nan	1.000000	20.907300	18.223000
5500	nan	3.297620	1.000000	20.979900	18.160000
TrainOutput(global_step=5895, training_loss=nan, metrics={'train_runtime': 1444.0059, 'train_samples_per_second': 4.082, 'total_flos': 2.494079663135328e+17, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 1760702464, 'init_mem_gpu_alloc_delta': 377847808, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2907852800, 'train_mem_gpu_alloc_delta': 1116327936, 'train_mem_cpu_peaked_delta': 847519744, 'train_mem_gpu_peaked_delta': 2587473920})

The only differences from the notebook code are that I’m using torchaudio instead of soundfile and that I’m downsampling to 16 kHz. The downsampling seems to be working correctly because the audio plays well when I check it. That code is here:

import torchaudio

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["file_name"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["phrase"]
    return batch

dataset = dataset.map(speech_file_to_array_fn, remove_columns=dataset.column_names["train"], writer_batch_size=32, num_proc=4)

import librosa
import numpy as np

def resample(batch):
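    # Keyword arguments work across librosa versions; newer releases no longer accept the rates positionally.
    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), orig_sr=48_000, target_sr=16_000)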
    batch["sampling_rate"] = 16_000
    return batch

dataset = dataset.map(resample, writer_batch_size=32, num_proc=4)
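Listening to a few clips is a useful check, but a numeric pass over the data can also catch non-finite samples, a wrong sampling rate, or empty clips. The following is only a minimal sketch: `audio_sanity_report` is a hypothetical helper, not part of the notebook, and the synthetic tone stands in for a real clip.

```python
import numpy as np

def audio_sanity_report(speech, sampling_rate, expected_rate=16_000):
    """Return simple diagnostics for one decoded clip (hypothetical helper)."""
    samples = np.atleast_1d(np.asarray(speech, dtype=np.float32))
    has_data = samples.size > 0
    return {
        "rate_ok": sampling_rate == expected_rate,          # matches what the model expects
        "all_finite": bool(np.isfinite(samples).all()),     # no nan/inf in the waveform
        "is_mono": samples.ndim == 1,                       # one channel only
        "duration_s": samples.shape[-1] / float(sampling_rate) if has_data else 0.0,
        "peak": float(np.abs(samples).max()) if has_data else 0.0,
    }

# Synthetic one-second 440 Hz tone standing in for a real clip.
t = np.arange(16_000) / 16_000
report = audio_sanity_report(0.5 * np.sin(2 * np.pi * 440 * t), 16_000)
print(report)
```

Applied to every example (for instance with `dataset.map`), a report like this makes it easier to spot clips that should be dropped or re-decoded before training.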

I’ve tried different learning rates. At a couple of the 500-step evaluation points in the table above, a numeric loss was reported instead of nan, but the subsequent losses were nan again.

I also tried to follow the example in the Fine-tuning XLSR-Wav2Vec2 for Multi-Lingual ASR with 🤗 Transformers notebook. That resulted in a different error saying that the blank was out of label range.

I think I have a vanishing/exploding gradient problem. Perhaps I should try changing learning rates and regularization like dropout. I’m not sure how to make it work in this notebook.

Any help would be appreciated.

The dataset is from here: Medical Speech, Transcription, and Intent | Kaggle

Notebook is run on Google Colab Pro.

Any help would be much appreciated.

Thanks,
Sidd

Hello,
I had the same issue. I can share my solution, but I don’t know if it will work for you.
The PyTorch documentation (CTCLoss — PyTorch 1.8.1 documentation) says:

The alignment of input to target is assumed to be “many-to-one”, which limits the length of the target sequence such that it must be ≤ the input length.

Sometimes the model’s output sequence was shorter than the target transcription, which is why I got “inf” and “nan” during training. To work around this, you need to enable zero_infinity:

  • zero_infinity ( bool , optional ) – Whether to zero infinite losses and the associated gradients. Default: False Infinite losses mainly occur when the inputs are too short to be aligned to the targets.

You can enable it in your code like this:

model = Wav2Vec2ForCTC.from_pretrained(path_2_model)
model.config.ctc_zero_infinity = True
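To illustrate what this flag changes, here is a minimal sketch using `torch.nn.functional.ctc_loss` directly, with toy shapes that are assumptions chosen for the demonstration (3 input frames, 4 target labels, 5 classes with index 0 as the blank). Because the target is longer than the input, no valid alignment exists and the default loss is infinite.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy example: T=3 frames, batch size 1, C=5 classes (index 0 is the blank).
log_probs = torch.randn(3, 1, 5).log_softmax(dim=-1)
targets = torch.tensor([[1, 2, 3, 4]])   # 4 labels cannot be aligned to 3 frames
input_lengths = torch.tensor([3])
target_lengths = torch.tensor([4])

# Default behaviour: the impossible alignment yields an infinite loss.
loss_default = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                          blank=0, reduction="sum")

# With zero_infinity=True the infinite loss (and its gradient) is replaced by zero.
loss_zeroed = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                         blank=0, reduction="sum", zero_infinity=True)

print(loss_default.item(), loss_zeroed.item())
```

Note that the flag only masks the affected examples; they contribute nothing to training, so it may still be worth finding out why the inputs are too short for their transcriptions.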

Hope it will solve your problem.

Best,

Omar

Hi Omar, first of all, thank you for your response. I did try your suggestion. Unfortunately, it does not solve the problem. -Sidd

@ssaran Did you eventually get around this issue? I’m also finding myself with inf/nan losses as I’m downstream training Wav2Vec2ForCTC models. 😐

Unfortunately, I did not. While I continued to use Wav2Vec2, I didn’t fine tune it to my dataset. I might try it again another time.

Aw, okay. Thanks for the reply though!

Have you tried to reduce your learning rate?

There was an issue with the shape of my data. It works fine after fixing the bug! But to add some info for someone else who might be having a similar problem, cleaning the speech before finetuning could help!

Can you please clarify what the issue with your data shapes was? I am also having the same issue and I can’t solve it.

It was a problem with the sampling rate in my case. I was mistakenly using too short segments of audio so the model was unable to learn anything. You might as well try cleaning your audio if it’s too noisy.

Thanks for answering. I have a few training instances that are <1 second duration. You think they might cause a problem?

I also ended up trying this previous suggestion

model.config.ctc_zero_infinity = True

and it seems to be working so far.

That could be it. You should check the contents of short training instances and filter out ones that might be too short to contain any valid speech.
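One way to make "too short" concrete is to compare the number of frames the model produces for a clip with the number of frames CTC needs for its transcription. The sketch below assumes the default feature-encoder kernels and strides of wav2vec2-base; the helper names are hypothetical, and the label ids are made up for the example.

```python
# Default conv_kernel / conv_stride values of the wav2vec2-base feature encoder (assumed here).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def num_output_frames(num_samples):
    """Number of frames the convolutional feature encoder emits for a raw waveform."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return max(length, 0)

def min_frames_for_labels(label_ids):
    """CTC needs one frame per label plus a blank between identical neighbours."""
    repeats = sum(1 for a, b in zip(label_ids, label_ids[1:]) if a == b)
    return len(label_ids) + repeats

def is_alignable(num_samples, label_ids):
    """True if the clip is long enough for CTC to align it to its labels."""
    return num_output_frames(num_samples) >= min_frames_for_labels(label_ids)

labels = list(range(1, 41))                      # 40 made-up label ids
print(num_output_frames(16_000))                 # frames for one second at 16 kHz
print(is_alignable(16_000, labels), is_alignable(4_000, labels))
```

With a 🤗 Datasets object, a predicate like `is_alignable` could be passed to `dataset.filter`, assuming the prepared columns are named `input_values` and `labels` as in the notebook.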

Will do, thank you.