How to use whole word masking data_collator?

The course chapter on fine-tuning a masked language model introduces a function for whole-word masking:

import collections
import numpy as np

from transformers import default_data_collator

wwm_probability = 0.2

def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                # Record this token position under the word it belongs to
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
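        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                # `tokenizer` is assumed to be the tokenizer loaded earlier in the notebook
                input_ids[idx] = tokenizer.mask_token_id
        # Store the labels so that only masked positions contribute to the loss
        feature["labels"] = new_labels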

    return default_data_collator(features)
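
To make the word-to-token mapping step concrete, here is a minimal sketch that runs the same loop on a hand-written word_ids list. The helper name build_word_mapping and the sample input are made up for illustration; in practice word_ids comes from the tokenizer.

```python
import collections


def build_word_mapping(word_ids):
    """Group token positions by the word they belong to (special tokens are skipped)."""
    mapping = collections.defaultdict(list)
    current_word_index = -1
    current_word = None
    for idx, word_id in enumerate(word_ids):
        if word_id is not None:
            if word_id != current_word:
                current_word = word_id
                current_word_index += 1
            mapping[current_word_index].append(idx)
    return dict(mapping)


# Hypothetical word_ids: special token, one-token word, a word split into
# two sub-word tokens, another one-token word, special token.
sample_word_ids = [None, 0, 1, 1, 2, None]
word_to_tokens = build_word_mapping(sample_word_ids)
print(word_to_tokens)  # {0: [1], 1: [2, 3], 2: [4]}
```

Masking word 1 would therefore replace both tokens 2 and 3, which is what makes the masking "whole word".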

However, the course didn’t tell us how to use it!

I tried to pass this function to the Trainer as follows:

trainer = Trainer(

and I got the following error:

The following columns in the training set  don't have a corresponding argument in `DistilBertForMaskedLM.forward` and have been ignored: word_ids.
KeyError                                  Traceback (most recent call last)
/var/folders/ts/ft1kkj55399gmd5c5cr535dm0000gn/T/ipykernel_2318/ in <module>
      6     data_collator=whole_word_masking_data_collator,
      7 )
----> 8 trainer.train()

~/miniforge3/envs/torch/lib/python3.9/site-packages/transformers/ in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1373             step = -1
-> 1374             for step, inputs in enumerate(epoch_iterator):
   1376                 # Skip past any already trained steps if resuming training

~/miniforge3/envs/torch/lib/python3.9/site-packages/torch/utils/data/ in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/miniforge3/envs/torch/lib/python3.9/site-packages/torch/utils/data/ in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

~/miniforge3/envs/torch/lib/python3.9/site-packages/torch/utils/data/_utils/ in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

/var/folders/ts/ft1kkj55399gmd5c5cr535dm0000gn/T/ipykernel_2318/ in whole_word_masking_data_collator(features)
      7 def whole_word_masking_data_collator(features):
      8     for feature in features:
----> 9         word_ids = feature.pop("word_ids")
     10         """

KeyError: 'word_ids'

This is strange, because I checked my input dataset and the key “word_ids” does exist:

    features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
    num_rows: 10000

I don’t know what is going wrong. Could anyone help me?

You have to add remove_unused_columns=False to the TrainingArguments. Otherwise, the Trainer automatically removes the word_ids column, since it’s not expected by the model.
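
To illustrate why the KeyError appears, here is a simplified, library-free sketch of the idea: columns whose names don't match an argument of the model's forward are filtered out before the collator runs. The names toy_forward and drop_unused_columns are hypothetical stand-ins, not the Trainer's actual code.

```python
import inspect


def toy_forward(input_ids=None, attention_mask=None, labels=None):
    """Hypothetical stand-in for a model's forward(); it accepts no word_ids."""
    return None


def drop_unused_columns(example, forward_fn):
    """Keep only the keys that match a parameter name of forward_fn."""
    accepted = set(inspect.signature(forward_fn).parameters)
    return {key: value for key, value in example.items() if key in accepted}


example = {
    "input_ids": [101, 7592, 2088, 102],
    "attention_mask": [1, 1, 1, 1],
    "word_ids": [None, 0, 1, None],
    "labels": [101, 7592, 2088, 102],
}

# Default behaviour: word_ids is filtered out before the collator sees it,
# so the collator's pop("word_ids") fails.
filtered = drop_unused_columns(example, toy_forward)
try:
    filtered.pop("word_ids")
    collator_error = None
except KeyError as err:
    collator_error = err

# With the filtering switched off, the example reaches the collator intact.
word_ids = dict(example).pop("word_ids")
print(repr(collator_error), word_ids)
```

Setting remove_unused_columns=False corresponds to skipping the filtering step, so the collator still finds word_ids.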

Thanks @sgugger for the quick reply! It works!

By the way, since the model doesn’t expect word_ids, how will the model handle it when I set remove_unused_columns=False?
I’m sorry, it’s too hard for me to read the source code directly to figure it out.

The model won’t receive it, since your data collator removes it with feature.pop("word_ids") before the batch is built.
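
A quick way to convince yourself is a toy collator that pops word_ids and then batches whatever is left. This is only a sketch under assumptions: toy_collator is a hypothetical simplification that collects values into lists rather than tensors, and the feature values are made up.

```python
import collections


def toy_collator(features):
    """Simplified stand-in for the collator: pop word_ids, then batch the rest."""
    batch = collections.defaultdict(list)
    for feature in features:
        feature.pop("word_ids")  # consumed here, so it never reaches the model
        for key, value in feature.items():
            batch[key].append(value)
    return dict(batch)


features = [
    {
        "input_ids": [101, 2023, 102],
        "attention_mask": [1, 1, 1],
        "word_ids": [None, 0, None],
        "labels": [101, 2023, 102],
    },
    {
        "input_ids": [101, 2008, 102],
        "attention_mask": [1, 1, 1],
        "word_ids": [None, 0, None],
        "labels": [101, 2008, 102],
    },
]

batch = toy_collator(features)
print(sorted(batch))  # ['attention_mask', 'input_ids', 'labels']
```

Only the keys left in the batch are passed on to the model, so word_ids is used for masking and then discarded.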

Oh I see. Thanks!