SHAP Implementation In BERT

I’m trying to implement SHAP in BERT Model. I’ve been getting the same key error ‘label’ when I’m trying to pass the input text in shap explainer object. I’m using SHAP (0.44.1) , TORCH (2.2.1) , PYTHON (3.8)

 from transformers import pipeline
  from transformers import TokenClassificationPipeline,AutoTokenizer
  import shap
  from typing import Union, List

  model_checkpoint = "saved_model_test"
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
  pipe = pipeline("token-classification", model=model_checkpoint, stride=50, 
 aggregation_strategy="first")
  def score_and_visualize(text,shap_values):
      prediction = pipe(text)
      print(f"Model predictions are: {prediction}")
      shap.plots.text(shap_values)


  text = "India"   



  explainer = shap.Explainer(pipe)
  shap_values = explainer(text)

  score_and_visualize(text,shap_values)

And this is my error

   KeyError                                  Traceback (most recent call last)
   Cell In[27], line 38
        35 #pipe.set_labels(labels)
        37 explainer = shap.Explainer(pipe)
   --> 38 shap_values = explainer([text])
        40 score_and_visualize(text,shap_values)

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_partition.py:129, in 
 PartitionExplainer.__call__(self,  
  max_evals, fixed_context, main_effects, error_bounds, batch_size, outputs, silent, *args)
       125 def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, 
 error_bounds=False, 
  batch_size="auto",
       126              outputs=None, silent=False):
       127     """ Explain the output of the model on the given arguments.
       128     """
   --> 129     return super().__call__(
       130         *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, 
  error_bounds=error_bounds, batch_size=batch_size,
       131         outputs=outputs, silent=silent
       132     )

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_explainer.py:267, in 
 Explainer.__call__(self, max_evals, 
  main_effects, error_bounds, batch_size, outputs, silent, *args, **kwargs)
       265     feature_names = [[] for _ in range(len(args))]
       266 for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
   --> 267     row_result = self.explain_row(
        268         *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
       269         batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
       270     )
       271     values.append(row_result.get("values", None))
       272     output_indices.append(row_result.get("output_indices", None))

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_partition.py:154, in PartitionExplainer.explain_row(self, 
  max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context, *row_args)
       152 # if not fixed background or no base value assigned then compute base value for a row
       153 if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
   --> 154     self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0] # the zero index param tells the masked model 
   what the baseline is
       155 f11 = fm(~m00.reshape(1, -1))[0]
       157 if callable(self.masker.clustering):

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\utils\_masked_model.py:69, in MaskedModel.__call__(self, masks, 
  zero_index, batch_size)
        66         return self._full_masking_call(full_masks, zero_index=zero_index, batch_size=batch_size)
        68 else:
   --> 69     return self._full_masking_call(masks, batch_size=batch_size)

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\utils\_masked_model.py:146, in 
  MaskedModel._full_masking_call(self, masks, zero_index, batch_size)
        143         all_masked_inputs[i].append(v)
       145 joined_masked_inputs = tuple([np.concatenate(v) for v in all_masked_inputs])
   --> 146 outputs = self.model(*joined_masked_inputs)
        147 _assert_output_input_match(joined_masked_inputs, outputs)
      148 all_outputs.append(outputs)

   File ~\Desktop\Python\bert-env\Lib\site-packages\shap\models\_transformers_pipeline.py:36, in 
  TransformersPipeline.__call__(self, strings)
        34         val = [val]
        35     for obj in val:
   --> 36         output[i, self.label2id[obj["label"]]] = scipy.special.logit(obj["score"]) if self.rescale_to_logits else obj["score"]
        37 return output
     Key Error 'label'