Can someone give me a simple example of how to train Wav2Vec2ForAudioFrameClassification?
The docs only provide an inference example, and I'm new to data science…
The model expects the labels to be only 2D, regardless of the number of classes: one-hot encode the per-frame class ids, then flatten the frame and class dimensions together. For a binary problem, I prepared the labels like this:
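```python
import torch

num_classes = 2
# labels: integer class ids, shape (batch_size, num_frames).
# Hypothetical toy example; use your real per-frame labels here.
labels = torch.randint(0, num_classes, (4, 49))

check_labels = labels.long()  # keep the original ids for the check below
# one-hot encode -> (batch_size, num_frames, num_classes)
labels = torch.nn.functional.one_hot(labels.long(), num_classes=num_classes)
# flatten to 2D -> (batch_size, num_frames * num_classes)
labels = labels.view(labels.shape[0], labels.shape[1] * labels.shape[2])

# sanity check: decode each row the same way the model does
for i in range(len(check_labels)):
    check_against = labels[i]
    # this is the line used in the model
    check_against = torch.argmax(check_against.view(-1, num_classes), axis=1).long()
    torch.testing.assert_close(check_labels[i], check_against, rtol=1e-5, atol=1e-5)
```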
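To tie this back to training, here is a minimal training-step sketch. It is only a sketch under assumptions: it builds a tiny, randomly initialised model from a hypothetical small `Wav2Vec2Config` so it runs without downloading a checkpoint, feeds random audio with random frame labels, and uses AdamW with an arbitrary learning rate. With real data you would load a pretrained checkpoint with `from_pretrained(..., num_labels=2)` instead. The labels are prepared the same way as above, and when `labels` is passed the model computes the cross-entropy loss itself.

```python
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForAudioFrameClassification

torch.manual_seed(0)

# Hypothetical tiny, randomly initialised model so the sketch runs offline.
# In practice: Wav2Vec2ForAudioFrameClassification.from_pretrained(<checkpoint>, num_labels=2)
config = Wav2Vec2Config(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    conv_dim=(32,) * 7,
    num_labels=2,
)
model = Wav2Vec2ForAudioFrameClassification(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # arbitrary learning rate

# Hypothetical toy batch: 4 clips of 1 s of 16 kHz audio.
batch_size, num_samples, num_classes = 4, 16000, 2
input_values = torch.randn(batch_size, num_samples)

# One label is needed per *output frame*, so ask the model how many frames it produces.
with torch.no_grad():
    num_frames = model(input_values).logits.shape[1]
frame_labels = torch.randint(0, num_classes, (batch_size, num_frames))

# Same preparation as in the post: one-hot, then flatten to 2D.
labels = torch.nn.functional.one_hot(frame_labels, num_classes=num_classes)
labels = labels.view(batch_size, num_frames * num_classes)

model.train()
for step in range(3):
    outputs = model(input_values=input_values, labels=labels)
    loss = outputs.loss  # cross-entropy computed inside the model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss={loss.item():.4f}")
```

The frame count is taken from a forward pass because the convolutional feature encoder downsamples the waveform (by 320 samples per frame with the default strides, so 1 s of 16 kHz audio gives 49 frames); with a real dataset your labels need to be at that frame rate.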