I’m trying to implement SHAP in BERT Model. I’ve been getting the same key error ‘label’ when I’m trying to pass the input text in shap explainer object. I’m using SHAP (0.44.1) , TORCH (2.2.1) , PYTHON (3.8)
from transformers import pipeline
from transformers import TokenClassificationPipeline,AutoTokenizer
import shap
from typing import Union, List
model_checkpoint = "saved_model_test"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
pipe = pipeline("token-classification", model=model_checkpoint, stride=50,
aggregation_strategy="first")
def score_and_visualize(text,shap_values):
prediction = pipe(text)
print(f"Model predictions are: {prediction}")
shap.plots.text(shap_values)
text = "India"
explainer = shap.Explainer(pipe)
shap_values = explainer(text)
score_and_visualize(text,shap_values)
And this is my error
KeyError Traceback (most recent call last)
Cell In[27], line 38
35 #pipe.set_labels(labels)
37 explainer = shap.Explainer(pipe)
--> 38 shap_values = explainer([text])
40 score_and_visualize(text,shap_values)
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_partition.py:129, in
PartitionExplainer.__call__(self,
max_evals, fixed_context, main_effects, error_bounds, batch_size, outputs, silent, *args)
125 def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False,
error_bounds=False,
batch_size="auto",
126 outputs=None, silent=False):
127 """ Explain the output of the model on the given arguments.
128 """
--> 129 return super().__call__(
130 *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects,
error_bounds=error_bounds, batch_size=batch_size,
131 outputs=outputs, silent=silent
132 )
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_explainer.py:267, in
Explainer.__call__(self, max_evals,
main_effects, error_bounds, batch_size, outputs, silent, *args, **kwargs)
265 feature_names = [[] for _ in range(len(args))]
266 for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
--> 267 row_result = self.explain_row(
268 *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
269 batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
270 )
271 values.append(row_result.get("values", None))
272 output_indices.append(row_result.get("output_indices", None))
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\explainers\_partition.py:154, in PartitionExplainer.explain_row(self,
max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context, *row_args)
152 # if not fixed background or no base value assigned then compute base value for a row
153 if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
--> 154 self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0] # the zero index param tells the masked model
what the baseline is
155 f11 = fm(~m00.reshape(1, -1))[0]
157 if callable(self.masker.clustering):
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\utils\_masked_model.py:69, in MaskedModel.__call__(self, masks,
zero_index, batch_size)
66 return self._full_masking_call(full_masks, zero_index=zero_index, batch_size=batch_size)
68 else:
--> 69 return self._full_masking_call(masks, batch_size=batch_size)
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\utils\_masked_model.py:146, in
MaskedModel._full_masking_call(self, masks, zero_index, batch_size)
143 all_masked_inputs[i].append(v)
145 joined_masked_inputs = tuple([np.concatenate(v) for v in all_masked_inputs])
--> 146 outputs = self.model(*joined_masked_inputs)
147 _assert_output_input_match(joined_masked_inputs, outputs)
148 all_outputs.append(outputs)
File ~\Desktop\Python\bert-env\Lib\site-packages\shap\models\_transformers_pipeline.py:36, in
TransformersPipeline.__call__(self, strings)
34 val = [val]
35 for obj in val:
--> 36 output[i, self.label2id[obj["label"]]] = scipy.special.logit(obj["score"]) if self.rescale_to_logits else obj["score"]
37 return output
Key Error 'label'