Wav2Vec2ForPreTraining: unable to run trainer.train()

I am trying to use Wav2Vec2ForPreTraining to train a model from scratch on my own audio dataset.

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining, TrainingArguments, Trainer

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")


# Train on the GPU when one is available; mixed precision (fp16) is only enabled there.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
fp16 = use_cuda

model = model.to(device)

logstep = 100

training_args = TrainingArguments(
    output_dir="./",
    group_by_length=True,
    per_device_train_batch_size=2,
    evaluation_strategy="steps",
    num_train_epochs=35,
    fp16=fp16,
    save_steps=2100,
    eval_steps=logstep,
    logging_steps=logstep,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    # Passing None does NOT disable reporting (Transformers v4 treats it as "all"
    # installed integrations); the string "none" turns reporting off.
    report_to="none",
    save_total_limit=1,
)

# NOTE: Wav2Vec2ForPreTraining.forward() has no `labels` parameter, and it only
# computes its contrastive loss when `mask_time_indices` is supplied (otherwise it
# fails on `mask_time_indices.long()`). `data_collator` must therefore build
# batches containing `input_values` plus `mask_time_indices` and must NOT emit a
# `labels` key -- a CTC-style collator that returns `labels` will not work here.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=test_data,
    # Use the feature extractor loaded above (`processor` is not defined in this script).
    tokenizer=feature_extractor,
)
trainer.train()

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-34-3435b262f1ae> in <module>
----> 1 trainer.train()

~/.conda/envs/torch/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1284                         tr_loss += self.training_step(model, inputs)
   1285                 else:
-> 1286                     tr_loss += self.training_step(model, inputs)
   1287                 self.current_flos += float(self.floating_point_ops(inputs))
   1288 

~/.conda/envs/torch/lib/python3.7/site-packages/transformers/trainer.py in training_step(self, model, inputs)
   1787         if self.use_amp:
   1788             with autocast():
-> 1789                 loss = self.compute_loss(model, inputs)
   1790         else:
   1791             loss = self.compute_loss(model, inputs)

~/.conda/envs/torch/lib/python3.7/site-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   1821         else:
   1822             labels = None
-> 1823         outputs = model(**inputs)
   1824         # Save past state if it exists
   1825         # TODO: this needs to be fixed and made cleaner later.

~/.conda/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'labels'

The input data is a dictionary in the following format:

{'input_values': tensor([[-0.0075, -0.0095, -0.0085,  ..., -1.0926, -1.1881, -1.1047],
        [ 0.5310,  0.9788,  1.4064,  ..., -0.1375, -0.1230, -0.1085]]), 'labels': tensor([[   3,    6,   12,   13,   13,    1,   22,    1,   26,   24,   28,    1,
            0,    6,   10,    1,   25,    4,    1,    3,    6,   13,    1,    4,
           27,    9,    4,   14,   12,   25,    9,   13,   12,    1,   10,   24,
            1,    3,    6,   13,    1,   24, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100],
        [   6,   26,   21,   13,    1,   26,    1,   20,   12,   13,   26,    3,
            1,   28,   26,   19,    1,   10,    6,    1,   24,   10,    3,    1,
           26,    1,    7,   12,   10,    9,    2,   13,   11,    1,   28,   25,
           28,    1,   19,   10,   27,    1,   24,   13,   13,   28,    1,   11,
           13,    1,    3,   10,    1,    3,   12,   26,   24,    4,   16,   13,
           12,    1,   19,   10,   27,    1,   10,   21,   13,   12,    1,    3,
           10,    1,   28,   13,    2,    3,   26,    1,   28,   13,   24,    3,
           26,    2]])}

Looking at trainer.py in transformers, I see that the error comes from the compute_loss function. Judging by this function, it seems I need to define a label_smoother, because `labels` is only removed from the inputs when one is set:

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.

        Args:
            model: the model, called as ``model(**inputs)``.
            inputs: the batch dict; ``"labels"`` is popped from it only when a label smoother is set.
            return_outputs: if True, also return the model outputs.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # `labels` are removed from `inputs` ONLY when a label smoother is configured.
        # Otherwise they stay in `inputs` and are forwarded to model(**inputs) below,
        # so the model's forward() must accept a `labels` keyword argument.
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # The label smoother computes the loss from the outputs and the popped labels.
            loss = self.label_smoother(outputs, labels)
        else:
            # Otherwise the model must have computed the loss itself.
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

I also tried the following change in compute_loss:

# Remove `labels` from the batch, since Wav2Vec2ForPreTraining.forward() has no `labels` parameter.
labels = inputs.pop("labels")
# NOTE: forward() still needs `mask_time_indices` to compute its loss; without it,
# it fails on `mask_time_indices.long()` (AttributeError on None).
outputs = model(**inputs)

This throws the following error:

~/.conda/envs/torch/lib/python3.7/site-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   1823         print(".... ", self.label_smoother)
   1824         print(" >>> ", labels)
-> 1825         outputs = model(**inputs)
   1826         # Save past state if it exists
   1827         # TODO: this needs to be fixed and made cleaner later.

~/.conda/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/.conda/envs/torch/lib/python3.7/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py in forward(self, input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)
   1299             # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
   1300             preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
-> 1301             target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
   1302             contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")
   1303 

AttributeError: 'NoneType' object has no attribute 'long'

Could someone please point me in the right direction?

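You can’t use this model with the Trainer as-is. The Trainer API is only compatible with models that compute their loss from the batch the data collator provides. Wav2Vec2ForPreTraining does not accept `labels`, and it only computes its contrastive loss when `mask_time_indices` is passed in.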

Could you please point me to a guide or resource on how to train a Wav2Vec2ForPreTraining model? I'd much appreciate your help.

As seen here, @patrickvonplaten is working on it. Expect a blog post soon.