Wav2Vec2: training and validation loss growing after a few epochs

Hi all, @patrickvonplaten,

I’m using Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53") to fine-tune on my own Spanish dataset (I’m actually using it in another ASR framework). I’ve limited the dataset to 100 hours of recordings: 105,150 utterances ranging from 1 to 10 seconds.
I’m following the example from this notebook: Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers by @patrickvonplaten.

My issue is that the training and validation losses decrease steadily during the first few steps, and then all metrics start to get worse.

My parameters and config are:

from transformers import Wav2Vec2ForCTC, TrainingArguments

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    layerdrop=0.05,
    mask_time_prob=0.0075,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    mask_time_length=4,
    vocab_size=len(processor.tokenizer)
)
training_args = TrainingArguments(
    output_dir=repo_name,
    group_by_length=True,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=100,
    gradient_checkpointing=True,
    fp16=True,
    save_steps=1000,
    eval_steps=500,
    logging_steps=50,
    learning_rate=3e-4,
    warmup_steps=4000,
    save_total_limit=2,
    push_to_hub=False,
)
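
For reference, the number of optimizer updates these arguments imply can be worked out by hand. This is a minimal sketch, assuming a single GPU (so the effective batch is per_device_train_batch_size × gradient_accumulation_steps = 32) and that all 105,150 utterances are in the training split; update_steps is a hypothetical helper, not a Trainer API:

```python
import math

def update_steps(num_examples, per_device_batch, grad_accum, num_epochs, n_gpu=1):
    """Approximate optimizer updates per epoch and in total (hypothetical helper)."""
    # Batches the dataloader yields per epoch (the last partial batch is kept).
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * n_gpu))
    # One optimizer update every `grad_accum` batches, as Trainer counts them.
    per_epoch = max(batches_per_epoch // grad_accum, 1)
    return per_epoch, per_epoch * num_epochs

per_epoch, total = update_steps(105_150, per_device_batch=16, grad_accum=2, num_epochs=100)
print(per_epoch, total)             # 3286 328600
print(round(4000 / per_epoch, 2))   # 1.22 (epochs spent in warmup)
```

Under these assumptions, warmup_steps=4000 means the learning rate is still ramping up for roughly the first 1.2 epochs.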

Interesting! Could you maybe change the following:

mask_time_prob=0.0075,

to

mask_time_prob=0.75

and

mask_time_length=4,

to

mask_time_length=10,

Especially mask_time_length=4 is quite unusual, and I think it has been shown that this doesn’t really help.
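
To see how little the original settings mask, it helps to count spans. The sketch below assumes the default Wav2Vec2 feature-encoder kernels (10, 3, 3, 3, 3, 2, 2) and strides (5, 2, 2, 2, 2, 2, 2), 16 kHz audio, and the default mask_time_min_masks=2, and mirrors the span-count formula int(mask_prob * length / mask_length + epsilon) used in Transformers' _compute_mask_indices. The function names are hypothetical, epsilon is fixed at 0 here (the library draws it at random), and the library's upper clamps are left out:

```python
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)   # assumed Wav2Vec2 feature-encoder defaults
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def num_frames(num_samples):
    """Frames the conv feature encoder produces for `num_samples` raw samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length

def num_masked_spans(frames, mask_prob, mask_length, min_masks=2, epsilon=0.0):
    """Span count following int(mask_prob * frames / mask_length + epsilon), floored at min_masks."""
    return max(int(mask_prob * frames / mask_length + epsilon), min_masks)

frames = num_frames(10 * 16_000)  # a 10 s clip at 16 kHz
for prob, length in [(0.0075, 4), (0.75, 10)]:
    spans = num_masked_spans(frames, prob, length)
    # spans * length is an upper bound on masked frames; overlapping spans cover fewer.
    print(prob, length, spans, spans * length)
```

For a 10 s clip (499 frames) this reports 2 spans covering at most 8 frames with the original values, a count that comes entirely from the min_masks floor, versus 37 spans covering up to 370 frames with the suggested ones.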

It is a bit confusing though that your model doesn’t seem to be able to overfit. Could you also reduce warmup_steps to just 1000 and maybe set layerdrop=0.0?
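
The effect of a shorter warmup can be checked against Trainer's default linear schedule (lr_scheduler_type="linear": linear warmup to the peak rate, then linear decay to 0). This is a pure-Python sketch of that schedule, assuming roughly 328,600 total updates (105,150 utterances, an effective batch of 32 on one GPU, 100 epochs); linear_lr is a hypothetical name:

```python
def linear_lr(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup to peak_lr, then linear decay to 0 (same shape as get_linear_schedule_with_warmup)."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

TOTAL_STEPS = 328_600  # assumption: 105,150 utterances, effective batch 32, 100 epochs
for warmup in (4000, 1000):
    # Learning rate at a few early steps for each warmup setting.
    lrs = [linear_lr(s, 3e-4, warmup, TOTAL_STEPS) for s in (500, 1000, 2000, 4000)]
    print(warmup, ["%.2e" % lr for lr in lrs])
```

With warmup_steps=1000 the peak learning rate of 3e-4 is reached after 1,000 updates instead of 4,000.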

Thank you for your reply, @patrickvonplaten, but after some steps I get the following error when I set mask_time_length=10:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-0a4dde6faf17> in <module>
      1 # test
----> 2 trainer.train()

/usr/local/lib/python3.6/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1330                         tr_loss_step = self.training_step(model, inputs)
   1331                 else:
-> 1332                     tr_loss_step = self.training_step(model, inputs)
   1333 
   1334                 if (

/usr/local/lib/python3.6/dist-packages/transformers/trainer.py in training_step(self, model, inputs)
   1889 
   1890         with self.autocast_smart_context_manager():
-> 1891             loss = self.compute_loss(model, inputs)
   1892 
   1893         if self.args.n_gpu > 1:

/usr/local/lib/python3.6/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   1921         else:
   1922             labels = None
-> 1923         outputs = model(**inputs)
   1924         # Save past state if it exists
   1925         # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in forward(self, input_values, attention_mask, output_attentions, output_hidden_states, return_dict, labels)
   1659             output_attentions=output_attentions,
   1660             output_hidden_states=output_hidden_states,
-> 1661             return_dict=return_dict,
   1662         )
   1663 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in forward(self, input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)
   1286         hidden_states, extract_features = self.feature_projection(extract_features)
   1287         hidden_states = self._mask_hidden_states(
-> 1288             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
   1289         )
   1290 

/usr/local/lib/python3.6/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in _mask_hidden_states(self, hidden_states, mask_time_indices, attention_mask)
   1233                 mask_length=self.config.mask_time_length,
   1234                 attention_mask=attention_mask,
-> 1235                 min_masks=self.config.mask_time_min_masks,
   1236             )
   1237             mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)

/usr/local/lib/python3.6/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in _compute_mask_indices(shape, mask_prob, mask_length, attention_mask, min_masks)
    241         # get random indices to mask
    242         spec_aug_mask_idx = np.random.choice(
--> 243             np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
    244         )
    245 

mtrand.pyx in numpy.random.mtrand.RandomState.choice()

ValueError: Cannot take a larger sample than population when 'replace=False'

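In this case, input_length=9 and mask_length=10 (input_length here counts feature-encoder frames, not raw audio samples), so np.arange(input_length - (mask_length - 1)) is np.arange(0) = [], an empty array, and np.random.choice cannot draw num_masked_span indices from it without replacement.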
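
The failure mode can be reproduced with NumPy alone: whenever the number of requested spans exceeds the number of valid start positions, input_length - (mask_length - 1), choice(..., replace=False) raises a ValueError. The sketch below is a simplified, hypothetical stand-in for that sampling step (not the library code); it adds a clamp so the span count never exceeds the available starts, which is one way to avoid the error:

```python
import numpy as np

def sample_span_starts(input_length, num_spans, mask_length, rng, clamp=True):
    """Draw distinct start frames for `num_spans` masks of `mask_length` frames each.

    Simplified, hypothetical stand-in for the sampling step in _compute_mask_indices.
    """
    # A span of mask_length frames can only start at 0 .. input_length - mask_length.
    num_starts = max(input_length - (mask_length - 1), 0)
    if clamp:
        # Never ask for more distinct starts than exist (0 if the clip is shorter than a span).
        num_spans = min(num_spans, num_starts)
        if num_spans == 0:
            return np.empty(0, dtype=np.int64)
    return rng.choice(np.arange(num_starts), num_spans, replace=False)

rng = np.random.default_rng(0)
try:
    sample_span_starts(10, 2, mask_length=10, rng=rng, clamp=False)  # 1 valid start, 2 requested
except ValueError as err:
    print("unclamped:", err)
print(sample_span_starts(9, 2, mask_length=10, rng=rng))   # empty: clip shorter than one span
print(sample_span_starts(50, 2, mask_length=10, rng=rng))  # two distinct starts in [0, 40]
```

With the clamp, a clip that is too short simply gets no time masking instead of crashing the training step.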

@patrickvonplaten I think that this error was fixed in this issue: https://github.com/huggingface/transformers/issues/15366
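
For anyone on a Transformers version without that fix, a possible workaround is to drop clips too short to hold the masks before training. This sketch assumes 16 kHz audio, the default feature-encoder kernels and strides, mask_time_length=10 and the default mask_time_min_masks=2; requiring mask_length + min_masks - 1 frames leaves room for min_masks distinct span starts. num_frames and min_samples_for_frames are hypothetical helpers, and the clip lengths are made-up examples:

```python
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)   # assumed Wav2Vec2 feature-encoder defaults
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def num_frames(num_samples):
    """Frames the conv feature encoder produces for `num_samples` raw samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length

def min_samples_for_frames(target_frames):
    """Smallest raw length (in samples) that yields at least `target_frames` frames."""
    n = 0
    while num_frames(n) < target_frames:  # num_frames is non-decreasing in n
        n += 1
    return n

MASK_LENGTH, MIN_MASKS = 10, 2
min_len = min_samples_for_frames(MASK_LENGTH + MIN_MASKS - 1)
print(min_len, min_len / 16_000)  # 3600 samples, 0.225 s at 16 kHz

# Hypothetical clip lengths in samples; keep only those long enough to be masked safely.
clip_lengths = [2_000, 3_599, 3_600, 16_000]
print([n for n in clip_lengths if n >= min_len])
```

With 🤗 Datasets, the same threshold could be applied through Dataset.filter on whichever column holds the raw sample count.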
