Zero-shot multi-label text classification on SageMaker

Greetings,

I have developed a script on my computer that does zero-shot multi-label text classification using XLM-RoBERTa.
I want to reproduce my work on SageMaker using the Hugging Face Inference Toolkit, and I'm having some trouble doing so.

Locally, I do the classification like this:

from transformers import pipeline

classifier = pipeline(model="joeddav/xlm-roberta-large-xnli", task="zero-shot-classification")

predictions = classifier(sequence_to_classify, candidate_labels, multi_label=True)

On SageMaker, I configure the model from the Hub and launch a batch transform job for inference, but I can't find where to set the multi_label parameter in the following:

from sagemaker.huggingface import HuggingFaceModel

# hub holds the model environment variables (e.g. HF_MODEL_ID, HF_TASK)
huggingface_model = HuggingFaceModel(
    transformers_version="4.17.0",
    pytorch_version="1.10.2",
    py_version="py38",
    env=hub,
    role=event["role"],
)

bt_output_key = f"s3://{event['bucket']}/{event['output_prefix']}/{event['execution_id']}"

hf_transformer = huggingface_model.transformer(
    instance_count=event["instance_count"],
    instance_type=event["instance_type"],
    output_path=bt_output_key,
    strategy="SingleRecord",
    max_concurrent_transforms=event["concurrent_transforms"],
)

hf_transformer.transform(
    data=event["input_s3_path"],
    content_type="application/json",
    split_type="Line",
    wait=False,
)

I looked through the list of environment variables, but I think I'm missing something.
Thank you for your help.

It looks like you are "customizing" the default behavior of the pipeline with the multi_label kwarg. For this, you either need to provide an inference.py that does what you are doing locally, or you need to modify your input JSON Lines file to include those parameters; an example is below. You can also check out the example notebook: notebooks/sagemaker-notebook.ipynb (huggingface/notebooks on GitHub).

{"inputs": "VirginAmerica plus you've added commercials to the experience... tacky.", "parameters": {"candidate_labels": ["refund", "legal", "faq"], "multi_label": true}}
{"inputs": "VirginAmerica I didn't today... Must mean I need to take another trip!", "parameters": {"candidate_labels": ["refund", "legal", "faq"], "multi_label": true}}
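For the inference.py route, a custom handler could look roughly like the sketch below. It assumes the toolkit's `model_fn(model_dir)` and `predict_fn(data, model)` override hooks, that the default JSON decoding hands `predict_fn` a dict with `inputs` and optional `parameters`, and that the model files are available in `model_dir`. `DEFAULT_LABELS` is a hypothetical placeholder for your own label set. How the script gets shipped (for example in a `code/` folder inside `model.tar.gz`) should be checked against the docs for your SDK and toolkit versions.

```python
# inference.py: a minimal sketch of a custom handler for the Hugging Face
# Inference Toolkit, assuming its model_fn / predict_fn override hooks.

# Hypothetical default labels, used when a record carries no "parameters".
DEFAULT_LABELS = ["refund", "legal", "faq"]


def model_fn(model_dir):
    """Load a zero-shot classification pipeline from the model directory."""
    # Imported lazily so this module can be imported (e.g. for local tests)
    # without loading transformers.
    from transformers import pipeline

    return pipeline("zero-shot-classification", model=model_dir)


def predict_fn(data, classifier):
    """Classify one decoded JSON record, with multi_label on by default."""
    sequence = data["inputs"]
    params = data.get("parameters", {})
    candidate_labels = params.get("candidate_labels", DEFAULT_LABELS)
    # Mirrors the local call: classifier(sequence, candidate_labels, multi_label=True)
    multi_label = params.get("multi_label", True)
    return classifier(sequence, candidate_labels, multi_label=multi_label)
```

With a handler like this, each line of the batch input could carry just `{"inputs": "..."}`, and the labels and `multi_label=True` would come from the script instead of the file.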