Can someone give me a simple example on how to train Wav2Vec2 for audio frame classification?

Can someone give me a simple example on how to train Wav2Vec2ForAudioFrameClassification?
The docs only supply an example for inference and I’m new to data science…


The model expects the labels to be 2D, regardless of the number of classes. For a binary problem, I prepared the labels like this:

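import torch

# `labels` is assumed to be a (batch, frames) tensor of 0/1 class ids.
check_labels = labels.long()  # keep the original class ids for the check below
labels = torch.nn.functional.one_hot(labels.long(), num_classes=2)

# Flatten (batch, frames, num_classes) to 2D: (batch, frames * num_classes).
labels = labels.view(labels.shape[0], labels.shape[1] * labels.shape[2])

# Sanity check: the argmax below is the line used in the model,
# so it should recover the original class ids.
num_classes = 2
for i in range(len(check_labels)):
    check_against = labels[i]
    check_against = torch.argmax(check_against.view(-1, num_classes), axis=1).long()
    torch.testing.assert_close(check_labels[i], check_against, rtol=1e-5, atol=1e-5)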
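To show how labels in this layout would feed a loss, here is a minimal sketch in plain PyTorch. It is not the library's training code: `frame_labels` and `logits` are made-up stand-ins (random class ids and random model outputs), and the cross-entropy call is my assumption about how the recovered class ids would be used, based on the argmax line above.

```python
import torch

torch.manual_seed(0)
batch, frames, num_labels = 2, 5, 2

# Hypothetical per-frame class ids, shape (batch, frames).
frame_labels = torch.randint(0, num_labels, (batch, frames))

# One-hot encode, then flatten to 2D: (batch, frames * num_labels).
one_hot = torch.nn.functional.one_hot(frame_labels, num_classes=num_labels)
flat_labels = one_hot.view(batch, frames * num_labels)

# Stand-in for the model's output logits, shape (batch, frames, num_labels).
logits = torch.randn(batch, frames, num_labels, requires_grad=True)

# Recover one class id per frame with the same argmax as in the check above,
# then compute a per-frame cross-entropy loss (assumed loss function).
targets = torch.argmax(flat_labels.view(-1, num_labels), dim=1)
loss = torch.nn.functional.cross_entropy(logits.view(-1, num_labels), targets)

# Backpropagate to confirm gradients reach the logits.
loss.backward()
```

In a real training loop the logits would come from the model's forward pass rather than `torch.randn`, and an optimizer step would follow `loss.backward()`.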