Sure, thanks for the help, @joeddav!
Here's the code I'm running on an example tweet.
```python
import numpy as np
from transformers import BartForSequenceClassification, BartTokenizer, pipeline

# TERMS: list of candidate labels (defined elsewhere, not shown)
# HYPOTHESES: each label in the proper template form
HYPOTHESES = ['This text is about ' + x for x in TERMS]

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
```
Using the model w/o the pipeline:
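The `get_labels` code used for this comparison isn't included in the post. As a rough illustration only (an assumption, not the code behind the numbers below), scoring labels without the pipeline typically means running each (tweet, hypothesis) pair through the MNLI model, dropping the neutral logit, and taking a softmax over contradiction vs. entailment; the entailment probability is that label's score. The sketch assumes the `facebook/bart-large-mnli` label order (0 = contradiction, 1 = neutral, 2 = entailment). `entailment_prob` and `labels_above_threshold` are hypothetical helpers that take the logits as plain lists, so the math can be checked without loading the model.

```python
import math
from typing import List, Sequence, Tuple

# Assumed label order for facebook/bart-large-mnli: 0 = contradiction, 1 = neutral, 2 = entailment
CONTRADICTION_ID = 0
ENTAILMENT_ID = 2


def entailment_prob(logits: Sequence[float]) -> float:
    """P(entailment) from a softmax over the contradiction and entailment logits only (neutral is dropped)."""
    c, e = logits[CONTRADICTION_ID], logits[ENTAILMENT_ID]
    m = max(c, e)  # subtract the max before exponentiating, for numerical stability
    exp_c, exp_e = math.exp(c - m), math.exp(e - m)
    return exp_e / (exp_c + exp_e)


def labels_above_threshold(
    labels: Sequence[str],
    logits_per_label: Sequence[Sequence[float]],
    threshold: float = 50.0,
) -> List[Tuple[str, float]]:
    """Score each label independently; keep (label, percent) pairs at or above threshold, highest first."""
    topics = []
    for label, logits in zip(labels, logits_per_label):
        score = entailment_prob(logits) * 100  # probability -> percentage
        if score >= threshold:
            topics.append((label, round(score, 2)))
    return sorted(topics, key=lambda pair: pair[1], reverse=True)


# Made-up logits for illustration only (not real model output)
example = labels_above_threshold(
    ["infrastructure", "sympathy"],
    [[-2.0, 0.5, 3.0], [1.0, 0.2, -1.0]],
)
print(example)  # prints [('infrastructure', 99.33)]
```

Because each label gets its own two-way softmax, the scores don't sum to 1 across labels, which matches the independent per-label behaviour of `multi_class=True` in the pipeline call.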
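Using the pipeline:
```python
# THRESHOLD: default cutoff in percent (defined elsewhere, not shown)
def get_labels_pipeline(tweet, threshold=THRESHOLD):
    """Get the labels for a tweet whose score meets the specified threshold."""
    topics = []
    # multi_class=True scores each label independently
    # (newer transformers versions call this flag multi_label)
    results = classifier(tweet, TERMS, multi_class=True)
    for idx, score in enumerate(results['scores']):
        score = score * 100  # probability -> percentage
        if score >= threshold:
            topics.append((results['labels'][idx], np.round(score, 2)))
    return topics
```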
Example:
```python
text = "West Bengal calls for Indian Army's support to restore essential infrastructure, services after Cyclone Amphan havoc CycloneAmphan Amphan AmphanUpdates"
```
W/o the pipeline, `get_labels(text, threshold=50)`:
```
[('resource availability', 50.59), ('relief measures', 85.47), ('infrastructure', 80.32), ('rescue', 66.81), ('news updates', 93.95), ('grievance', 79.94)]
```
With the pipeline, `get_labels_pipeline(text, threshold=50)`:
```
[('infrastructure', 98.93), ('relief measures', 95.18), ('grievance', 92.81), ('news updates', 83.83), ('power supply', 80.1), ('utilities', 76.64), ('sympathy', 75.98), ('water supply', 73.14), ('rescue', 70.47)]
```
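One difference worth ruling out, assuming `get_labels` builds its inputs from `HYPOTHESES`: the pipeline call above passes the bare `TERMS`, so the pipeline wraps each label in its default `hypothesis_template`, `"This example is {}."`, rather than `'This text is about '`. A minimal sketch of passing the same template explicitly is below. `get_labels_with_template` is a hypothetical helper that takes the classifier as an argument (so it can be exercised with a stub), and it uses the newer `multi_label` flag name; older transformers releases, like the call above, spell it `multi_class`.

```python
from typing import Callable, List, Sequence, Tuple


def get_labels_with_template(
    classifier: Callable,
    tweet: str,
    terms: Sequence[str],
    threshold: float = 50.0,
    template: str = "This text is about {}",
) -> List[Tuple[str, float]]:
    """Like get_labels_pipeline, but passes an explicit hypothesis template to the pipeline."""
    # {} is replaced by each candidate label, mirroring 'This text is about ' + x in HYPOTHESES
    results = classifier(tweet, list(terms), hypothesis_template=template, multi_label=True)
    topics = []
    for label, score in zip(results["labels"], results["scores"]):
        pct = score * 100  # probability -> percentage
        if pct >= threshold:
            topics.append((label, round(pct, 2)))
    return topics
```

With the real pipeline this would be called as `get_labels_with_template(classifier, text, TERMS, threshold=50)`; if the two methods then agree much more closely, the template was likely the main source of the gap.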
Thanks for the help again!