Multi-label text classification error

I’m trying to use SetFit for multi-label text classification. I can’t get either of the examples in the Hugging Face docs to work because I get the error below. I put the code together from the Hugging Face docs. Any pointers for overcoming this would be greatly appreciated.

Here is the doc page:

Here is the code:
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset

# Option 1: multi-output strategy
model = SetFitModel.from_pretrained(
    "BAAI/bge-small-en-v1.5",
    multi_target_strategy="multi-output",
)

# Option 2: one-vs-rest with a differentiable head
# (this reassigns `model`, so only this second model is trained below)
model = SetFitModel.from_pretrained(
    "BAAI/bge-small-en-v1.5",
    multi_target_strategy="one-vs-rest",
    use_differentiable_head=True,
    head_params={"out_features": 5},
)

dataset = load_dataset("SetFit/emotion")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=32)
test_dataset = dataset["test"]

# Preparing the training arguments

args = TrainingArguments(
batch_size=(32, 16),
num_epochs=(3, 8),
end_to_end=True,
body_learning_rate=(2e-5, 5e-6),
head_learning_rate=2e-3,
l2_weight=0.01,
)

# Preparing the trainer

trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)

trainer.train()

I get this error when trying to train.

Traceback (most recent call last):
  File "/Users/wc/dev/newco/nlp-test/play.py", line 59, in <module>
    trainer.train()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/setfit/trainer.py", line 410, in train
    self.train_embeddings(*full_parameters, args=args)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/setfit/trainer.py", line 443, in train_embeddings
    train_dataloader, loss_func, batch_size = self.get_dataloader(
                                              ^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/setfit/trainer.py", line 503, in get_dataloader
    data_sampler = ContrastiveDataset(
                   ^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/setfit/sampler.py", line 66, in __init__
    self.generate_multilabel_pairs()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/setfit/sampler.py", line 100, in generate_multilabel_pairs
    if any(np.logical_and(_label, label)):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'numpy.bool_' object is not iterable
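The last frame of the traceback points to a likely cause. When a `multi_target_strategy` is set, SetFit builds contrastive pairs by checking whether two examples share an active label, which assumes each label is a multi-hot list (one 0/1 entry per class). `SetFit/emotion` stores a single integer per example in its `label` column, so `np.logical_and` receives two scalars and returns one NumPy boolean, which `any()` cannot iterate. Below is a minimal sketch reproducing this with plain NumPy. `labels_overlap` only mirrors the check shown in the traceback and `to_multi_hot` is a hypothetical helper; neither is part of SetFit's API.

```python
import numpy as np


def labels_overlap(label_a, label_b):
    """Mirror of the failing check: True if two multi-hot label
    vectors share at least one active class."""
    return any(np.logical_and(label_a, label_b))


def to_multi_hot(label, num_classes):
    """Hypothetical helper: turn a single integer class id into a
    multi-hot list of length num_classes (one 0/1 entry per class)."""
    vector = [0] * num_classes
    vector[label] = 1
    return vector


# Integer labels, as in SetFit/emotion's "label" column:
# np.logical_and on two scalars yields a single NumPy bool, and any() cannot iterate it.
try:
    labels_overlap(3, 3)
except TypeError as err:
    print("scalar labels:", err)

# Multi-hot labels (the format the multi-target strategies expect) work as intended.
print(labels_overlap([0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0]))  # True: class 2 is shared
print(labels_overlap([1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0]))  # False: no shared class
print(to_multi_hot(3, num_classes=6))  # [0, 0, 0, 1, 0, 0]
```

Converting the emotion labels with something like `train_dataset.map(lambda ex: {"label": to_multi_hot(ex["label"], 6)})` (assuming the dataset's six emotion classes) should make the label format valid, but every row would still have exactly one active class. A dataset with genuinely multiple labels per example is the more meaningful test of a multi-label setup.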