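with torch.set_grad_enabled(True):
    model.train()
    for batch in dataloader:
        # Renamed from `input` to avoid shadowing the Python builtin.
        inputs = batch["input_val"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad()
        out = model(inputs)
        loss = loss_fn(out["logits"], labels)
        # .item() already returns a plain Python float copied off the device.
        loss_epoch += loss.item()
        # Note: this adds the batch size, so it counts samples, not steps.
        num_steps += inputs.shape[0]
        loss.backward()
        optimizer.step()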
Here, the model is:
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-conformer-rope-large",
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
).to(device)
After training, I got exactly the same results as with the untrained model.
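One way to narrow this down might be to check whether any parameters change after a single optimizer step. The following is a minimal sketch under stated assumptions: it uses a small stand-in `nn.Linear` model and random data rather than the wav2vec2-conformer model above, and the helper name `params_changed` is hypothetical, not part of PyTorch or transformers.

```python
import torch
from torch import nn


def params_changed(model, optimizer, loss_fn, inputs, labels):
    """Run one training step and report which parameters were updated.

    Hypothetical helper for illustration only.
    """
    # Snapshot every trainable parameter before the step.
    before = {
        name: p.detach().clone()
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()

    # A parameter counts as changed if any element differs from its snapshot.
    return {
        name: not torch.equal(before[name], p.detach())
        for name, p in model.named_parameters()
        if name in before
    }


torch.manual_seed(0)
toy_model = nn.Linear(4, 3)  # stand-in model, assumed for illustration
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_inputs = torch.randn(8, 4)          # random features, not real audio
toy_labels = torch.randint(0, 3, (8,))  # random class indices

changed = params_changed(
    toy_model, toy_optimizer, nn.CrossEntropyLoss(), toy_inputs, toy_labels
)
print(changed)
```

With the real model, the forward call would need adapting (for example `model(inputs)["logits"]`). If every entry came back `False`, that could point to frozen parameters or an optimizer built from a different model instance; if entries are `True`, the cause may lie elsewhere, such as the evaluation code.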