Zero-shot classification with manual PyTorch


I want to run a zero-shot classification task without using the transformers pipeline.

I have gone through the code snippet for the facebook/bart-large-mnli model, but I am unsure how to adapt it to more than 3 labels. My dataset has 52 labels.

from transformers import AutoModelForSequenceClassification, AutoTokenizer

nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

sequence = 'one day I will see the world'  # example text to classify
label = 'travel'                           # one example candidate label

premise = sequence
hypothesis = f'This example is {label}.'

# run through model pre-trained on MNLI
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')
logits = nli_model(x)[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:, [0, 2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:, 1]
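The snippet scores a single (premise, hypothesis) pair, so one way to extend it is to build one hypothesis per label and run them through the model as a batch. The sketch below assumes the hypothesis template from the snippet and a chunk size of 16 (both arbitrary); `nli_logits_for_labels` and `demo` are hypothetical helper names, not part of transformers. Row i of the result holds the contradiction/neutral/entailment logits for label i.

```python
import torch


def nli_logits_for_labels(model, tokenizer, premise, labels,
                          template='This example is {}.', batch_size=16):
    """Return NLI logits of shape (len(labels), 3), one row per label.

    Row i comes from the pair (premise, template.format(labels[i])).
    Labels are processed in chunks of `batch_size` to bound memory use.
    """
    rows = []
    with torch.no_grad():  # inference only, so skip gradient tracking
        for start in range(0, len(labels), batch_size):
            chunk = labels[start:start + batch_size]
            hypotheses = [template.format(label) for label in chunk]
            # Pair the same premise with every hypothesis in this chunk;
            # padding=True lets pairs of different lengths share one batch.
            enc = tokenizer([premise] * len(chunk), hypotheses,
                            return_tensors='pt', padding=True,
                            truncation='only_first')
            rows.append(model(**enc).logits)
    return torch.cat(rows, dim=0)


def demo(labels, sequence='one day I will see the world'):
    """Download facebook/bart-large-mnli and print P(label is true) per label."""
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    name = 'facebook/bart-large-mnli'
    nli_model = AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)

    logits = nli_logits_for_labels(nli_model, tokenizer, sequence, labels)
    # Same post-processing as the snippet, applied to every row at once:
    # drop "neutral" (1), softmax contradiction (0) vs. entailment (2).
    probs = logits[:, [0, 2]].softmax(dim=1)[:, 1]
    for label, p in zip(labels, probs.tolist()):
        print(f'{label}: {p:.3f}')
```

Calling `demo(my_52_labels)` downloads the model and prints one probability per label. Because every label is judged independently here, the 52 probabilities are not expected to sum to 1.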

How can I convert the above example to work with 52 labels?
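Once you have one row of NLI logits per label (shape `(52, 3)`), there are two common ways to turn them into label scores. As far as I understand the transformers zero-shot pipeline, its `multi_label` flag switches between exactly these two. The sketch below assumes the index mapping from the snippet's comment (0 = contradiction, 2 = entailment) for facebook/bart-large-mnli. For other NLI checkpoints, check `model.config.label2id` before hard-coding indices. `zero_shot_scores` is a hypothetical helper name.

```python
import torch


def zero_shot_scores(nli_logits, multi_label=False,
                     contradiction_id=0, entailment_id=2):
    """Turn (num_labels, 3) NLI logits into one score per candidate label.

    multi_label=True: each label is judged on its own by softmaxing its
    contradiction vs. entailment logits (the snippet's approach), so the
    scores need not sum to 1.
    multi_label=False: labels compete; the entailment logits of all labels
    are softmaxed together, so the scores sum to 1.
    """
    if multi_label:
        pair = nli_logits[:, [contradiction_id, entailment_id]]
        return pair.softmax(dim=1)[:, 1]
    return nli_logits[:, entailment_id].softmax(dim=0)
```

For a dataset where each example has exactly one of the 52 labels, the `multi_label=False` variant is usually the natural fit: `scores.argmax()` gives the predicted label index and `scores.topk(5)` the five most likely ones.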