I am using an RTX 4090. I would imagine that sequence classification would be fairly fast on it, but apparently not. A single forward pass over a sequence of about 4k tokens (which may or may not contain padding tokens) takes more than 20 minutes, so calling the predict function on a single row of data takes more than 40 minutes. At that rate it would take forever to run over my test dataset.
I am using the code below. The model is Mistral 7B.
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model_id = "my trained seq classification model based on Mistral"
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=1,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def predict(row):
    prompt = row.question
    chosen = f"Question: {prompt}\nAnswer: {row.chosen.strip()}"
    rejected = f"Question: {prompt}\nAnswer: {row.rejected.strip()}"
    with torch.no_grad():
        rewards_chosen = model(
            **tokenizer(chosen, return_tensors='pt')
        ).logits
        print('reward chosen is ', rewards_chosen)
        rewards_rejected = model(
            **tokenizer(rejected, return_tensors='pt')
        ).logits
        print('reward rejected is ', rewards_rejected)
    loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
    print(f"loss is {loss}")
    return (rewards_chosen.item(), rewards_rejected.item(), loss, rewards_chosen > rewards_rejected)
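In case it matters, I call predict roughly like this (a sketch; the file name and loading code are placeholders, but the rows come from a pandas DataFrame with question, chosen and rejected columns):

import pandas as pd

test_df = pd.read_csv("test.csv")  # placeholder path, not my real test set

results = []
for row in test_df.itertuples():
    # each call currently takes 40+ minutes
    results.append(predict(row))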
What could be causing inference to take this long, and how do I fix it?
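For reference, this is the kind of change I was thinking of trying, on the assumption that the slowness comes from the model never being moved to the GPU and therefore running in full precision on the CPU. It is only a sketch and I am not sure it is the right fix:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "my trained seq classification model based on Mistral"

# load the weights in half precision and place them on the GPU explicitly
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=1,
    torch_dtype=torch.float16,
).to("cuda")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# inside predict, the tokenized inputs would then also need to live on the GPU, e.g.
# inputs = tokenizer(chosen, return_tensors='pt').to(model.device)
# rewards_chosen = model(**inputs).logits

Would that be the right direction, or is something else going on?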