Hi guys, I am reading the code for whole word mask pretraining.
Here is a snippet of the _whole_word_mask function from DataCollatorForWholeWordMask (transformers/data_collator.py at 3f936df66287f557c6528912a9a68d7850913b9b · huggingface/transformers · GitHub):
def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
    """
    Get 0/1 labels for masked tokens with whole word mask proxy
    """
    if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
        warnings.warn(
            "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
            "Please refer to the documentation for more information."
        )

    cand_indexes = []
    for i, token in enumerate(input_tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue

        if len(cand_indexes) >= 1 and token.startswith("##"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    random.shuffle(cand_indexes)
    num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
    masked_lms = []
    covered_indexes = set()
    for index_set in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_lms.append(index)

    if len(covered_indexes) != len(masked_lms):
        raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
    mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
    return mask_labels
where the input tokens will be something like ['[CLS]', 'i', 'love', 'n', '##lp', '[SEP]'] for an original input
"i love nlp", and the function randomly selects the words to be masked following the whole word mask principle.
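
To show what I mean, here is a minimal standalone sketch of just the grouping step, copied out of the function above and run on that example (the shuffle/selection step is random, so I only show the deterministic grouping):

    input_tokens = ["[CLS]", "i", "love", "n", "##lp", "[SEP]"]

    # Group subword pieces (tokens starting with "##") with the word they belong to,
    # skipping the special tokens, exactly as in the loop above.
    cand_indexes = []
    for i, token in enumerate(input_tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        if len(cand_indexes) >= 1 and token.startswith("##"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    print(cand_indexes)  # [[1], [2], [3, 4]] -- "n" and "##lp" are grouped as one word

    # If the word "nlp" happens to be picked, both of its piece indexes are masked
    # together, so mask_labels would come out as [0, 0, 0, 1, 1, 0].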
My question is: how could the variable is_any_index_covered ever be set to True? In my understanding, there should not be any duplicate index across the different sub-lists of cand_indexes, so how could an index be covered more than once? Could someone give me an example? Thanks!
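
To make my reasoning concrete, here is a tiny check I wrote myself (not from the library): since the grouping loop appends each position i exactly once, either to a new sub-list or to the last one, the sub-lists should be pairwise disjoint, and flattening them should never produce a duplicate:

    cand_indexes = [[1], [2], [3, 4]]  # grouping result for the example tokens above

    # Each position was appended exactly once by the grouping loop,
    # so the flattened list should contain no duplicates.
    flat = [i for index_set in cand_indexes for i in index_set]
    assert len(flat) == len(set(flat)), "found a duplicate index"

The assertion passes for every input I have tried, which is why the is_any_index_covered branch looks unreachable to me.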