Apply batched zero-shot classification on a HuggingFace Datasets object

Hi,

UPDATE: notebook to reproduce: https://colab.research.google.com/drive/1t-ApjHqdSo90NoXSJ5baeh7h-gJx8bLt?usp=sharing

I have a large number of unlabeled texts, stored as a Pandas dataframe with just a single column called “text”.
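
For context, I turn the dataframe into a datasets.Dataset roughly like this (placeholder texts; the real dataframe is of course much larger):

import pandas as pd
from datasets import Dataset

# placeholder dataframe; the real one is much larger
df = pd.DataFrame({"text": ["some unlabeled text", "another unlabeled text"]})
dataset = Dataset.from_pandas(df)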

I’d like to apply zero-shot classification to all these texts in a batched way using HuggingFace Datasets’ .map(function, batched=True) functionality. I defined the function I want to apply to batches as follows:

def zero_shot_classify_sequences(examples, threshold=0.5):

    # first, send batch of texts through pipeline
    texts = examples['text']
    outputs = classifier(texts, candidate_labels, multi_label=True)
    # next, for each output:
    final_outputs = []
    for output in outputs:
        # create dictionary (predicted_labels, confidence)
        final_output = {}
        for label, score in zip(output['labels'], output['scores']):
            if score > threshold:
                final_output[label] = score
        final_outputs.append(final_output)

    assert len(final_outputs) == len(texts)
    # set final outputs
    examples['predicted_labels'] = final_outputs

    return examples

The candidate labels are defined outside of this function.
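They are created roughly like this (the checkpoint shown is just the usual zero-shot default, not necessarily the one from my notebook, and the label list is only an illustrative subset):

from transformers import pipeline

# zero-shot classification pipeline; "facebook/bart-large-mnli" is the common
# default checkpoint for this task (an assumption, not necessarily the one I use)
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# illustrative subset of the candidate labels (the full list is longer)
candidate_labels = ["Plant-based", "Retail tech",
                    "Delivery & fulfilment technology",
                    "Novel processing techniques & Equipments"]
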
In other words, I’d like to add a new column “predicted_labels” which, for a batch of texts, should be a list of dictionaries (each dictionary mapping labels to confidence values for a given text, keeping only the labels whose confidence value is > 0.5). However, when I do updated_dataset = dataset.map(zero_shot_classify_sequences, batched=True, batch_size=10), the output does not look like what I’d expect. For a given text, I get the following:

'predicted_labels': {'Delivery & fulfilment technology': None,
  'Novel processing techniques & Equipments': None,
  'Plant-based': None,
  'Retail tech': None}

This should not be the case: if none of the confidence values is higher than the threshold of 0.5, the “predicted_labels” dictionary should simply be empty for that example.
Does it perhaps have to do with a list of dictionaries not being supported by Apache Arrow? Or is it?

cc @lhoestq

Hi! Dictionaries are supported as long as they all have the same keys.
Otherwise, the missing keys are assigned a value of None.
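
For example, here is a small illustration (not from your notebook) of what happens when the dicts in a column don’t share the same keys:

from datasets import Dataset

# two dicts with different keys: Arrow stores the column as a struct with the
# union of the keys, and the missing keys are filled with None
ds = Dataset.from_dict({"predicted_labels": [{"a": 0.9}, {"b": 0.7}]})
print(ds[0]["predicted_labels"])  # {'a': 0.9, 'b': None}
print(ds[1]["predicted_labels"])  # {'a': None, 'b': 0.7}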

In your case, since you only keep certain labels depending on their scores, you can instead use two lists with fixed keys:

final_output["labels"] = [label for label, score in zip(output['labels'], output['scores']) if score > threshold]
final_output["scores"] = [score for score in output['scores'] if score > threshold]

Thanks for the reply. However, when using this, I’m getting the following error:

2 updated_dataset_batched = unlabeled_data.map(zero_shot_classify_sequences_batched_labels_and_scores, 
----> 3                                              batched=True, batch_size=2)

~/(...)/lib/python3.6/site-packages/datasets/arrow_dataset.py in map(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint)
   1307                 fn_kwargs=fn_kwargs,
   1308                 new_fingerprint=new_fingerprint,
-> 1309                 update_data=update_data,
   1310             )
   1311         else:

~/(...)/lib/python3.6/site-packages/datasets/arrow_dataset.py in wrapper(*args, **kwargs)
    202         }
    203         # apply actual function
--> 204         out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    205         datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    206         # re-apply format to the output

~/(...)/lib/python3.6/site-packages/datasets/fingerprint.py in wrapper(*args, **kwargs)
    335             # Call actual function
    336 
--> 337             out = func(self, *args, **kwargs)
    338 
    339             # Update fingerprint of in-place transforms + update in-place history of transforms

~/projecten/datascouts/env_datascouts/lib/python3.6/site-packages/datasets/arrow_dataset.py in _map_single(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, update_data)
   1580                     if update_data:
   1581                         batch = cast_to_python_objects(batch)
-> 1582                         writer.write_batch(batch)
   1583             if update_data:
   1584                 writer.finalize()  # close_stream=bool(buf_writer is None))  # We only close if we are writing in a file

~/(...)/lib/python3.6/site-packages/datasets/arrow_writer.py in write_batch(self, batch_examples, writer_batch_size)
    274             typed_sequence = TypedSequence(batch_examples[col], type=col_type, try_type=col_try_type)
    275             typed_sequence_examples[col] = typed_sequence
--> 276         pa_table = pa.Table.from_pydict(typed_sequence_examples)
    277         self.write_table(pa_table, writer_batch_size)
    278 

~/(...)/lib/python3.6/site-packages/pyarrow/table.pxi in pyarrow.lib.Table.from_pydict()

~/(...)/lib/python3.6/site-packages/pyarrow/array.pxi in pyarrow.lib.asarray()

~/(...)/lib/python3.6/site-packages/pyarrow/array.pxi in pyarrow.lib.array()

~/(...)/lib/python3.6/site-packages/pyarrow/array.pxi in pyarrow.lib._handle_arrow_array_protocol()

~/(...)/lib/python3.6/site-packages/datasets/arrow_writer.py in __arrow_array__(self, type)
     95                 out = pa.ExtensionArray.from_storage(type, pa.array(self.data, type.storage_dtype))
     96             else:
---> 97                 out = pa.array(self.data, type=type)
     98             if trying_type and out[0].as_py() != self.data[0]:
     99                 raise TypeError(

~/(...)/lib/python3.6/site-packages/pyarrow/array.pxi in pyarrow.lib.array()

~/(...)/lib/python3.6/site-packages/pyarrow/array.pxi in pyarrow.lib._sequence_to_array()

~/(...)/lib/python3.6/site-packages/pyarrow/error.pxi in pyarrow.lib.pyarrow_internal_check_status()

~/(...)/lib/python3.6/site-packages/pyarrow/error.pxi in pyarrow.lib.check_status()

ArrowInvalid: Invalid null value

My function looks as follows:

# let's use a higher threshold
def zero_shot_classify_sequences_batched(examples, threshold=0.6):
    # first, send text through pipeline
    texts = examples['text']
    outputs = classifier(texts, candidate_labels, multi_class=True)
    # next, for each output:
    final_outputs = []
    for output in outputs:
        final_output = {}
        final_output["labels"] = [label for label, score in zip(output['labels'], output['scores']) if score > threshold]
        final_output["scores"] = [score for score in output['scores'] if score > threshold]
        final_outputs.append(final_output)
    
    assert len(final_outputs) == len(texts)
    # set final outputs
    examples['predicted_labels'] = final_outputs
    
    return examples

From what I found in the code, the format transformation occurring in the dataset.map function is driven by the transmit_format(func) wrapper defined in datasets.arrow_dataset.py [184:219].

My guess is that this is necessary so that methods such as .to_pandas() or .to_csv() can be applied on the mapped dataset.
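
That is, so that something like this keeps working after the map (just a sketch, assuming the map call succeeded and updated_dataset is its output):

# sketch: these should still work on the mapped dataset
df_out = updated_dataset.to_pandas()
updated_dataset.to_csv("predictions.csv")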

I used these functions with your toy dataset:

def zero_shot_classify_sequences_v2(examples, threshold=0.5):
  with classifier.device_placement():
    outputs = classifier(examples['text'], candidate_labels, multi_label=True)

  final_outputs = []
  labels_set = set()
  for result in outputs:
    try:
      final_output = dict([(result['labels'][i], score) for i, score in enumerate(result['scores']) if score > threshold])
      labels_set.update(list(final_output.keys()))
    except TypeError:
      final_output = {}

    print("final_output before return: \n", final_output)
    final_outputs.append(final_output)

  print("Set of labels in threshold with this batch: ", len(labels_set), labels_set)
  examples['predicted_labels'] = final_outputs

  return examples


from pprint import pprint as pp  # assumed import for the pretty-printer used below

def print_dataset(dataset):
  for i, results in enumerate(dataset):
    pp((i, results))

Colab settings:

transformers: 4.5.0
python: sys.version_info(major=3, minor=7, micro=10, releaselevel='final', serial=0)

I found a few quirks with the map function:

A. Batching problems

  1. Using zero_shot_classify_sequences_v2, setting batched=False returns empty score dicts.
  2. Let b = the batch size, with N=4 (as in your toy example): if b={1,3} (i.e. odd and < N), this error is raised: ArrowInvalid: Column 1 named text expected length 3 but got length 1.

B. Reformatting problems:

  1. The column reformatting (in the map transformation) depends on the batch size. Examples:
# 1. batch_size=2
updated_dataset_v2b2 = dataset.map(zero_shot_classify_sequences_v2, batched=True, batch_size=2)
#last printed line:
>>>Set of labels in threshold with this batch:  7 {'Plant-based', 'Vertical & Indoor farming', 'Functional foods & Other alternative ingredients', 'B2B marketplaces', 'Meal planning & Nutrition', 'Meal replacements & Supplements', 'Food safety & Quality'}

# 2. batch_size=5
updated_dataset_v2b5 = dataset.map(zero_shot_classify_sequences_v2, batched=True, batch_size=5)
#last printed line:
>>>Set of labels in threshold with this batch:  13 {'Plant-based', 'Vertical & Indoor farming', 'Wine tech', 'Retail tech', 'B2B marketplaces', 'Supply chain monitoring & traceability', 'Functional foods & Other alternative ingredients', 'Ag marketplace', 'Meal planning & Nutrition', 'Delivery & fulfilment technology', 'Meal replacements & Supplements', 'Food safety & Quality', 'Drinks'}
  2. Missing new columns: the union over the labels appearing in the dicts is somehow missing the entries from the last batch if b < N. This is why only these 6 predicted labels are returned when b=2:
{'Ag marketplace',
 'Delivery & fulfilment technology',
 'Drinks',
 'Retail tech',
 'Supply chain monitoring & traceability',
 'Wine tech'}

But if b >= N, the resulting (common) columns are correct (i.e. 13 for b={4,5}):

{'Ag marketplace',
 'B2B marketplaces',
 'Delivery & fulfilment technology',
 'Drinks',
 'Food safety & Quality',
 'Functional foods & Other alternative ingredients',
 'Meal planning & Nutrition',
 'Meal replacements & Supplements',
 'Plant-based',
 'Retail tech',
 'Supply chain monitoring & traceability',
 'Vertical & Indoor farming',
 'Wine tech'}

It smells like a bug to me.