The following is a snippet from run_mlm_no_trainer.py at commit eb5bdcdfa51f743887ee1d9c7f230444d7a8b23c in huggingface/transformers on GitHub:
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result
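To make the issue concrete, here is a self-contained toy run of the same function on fake "tokenized" data, with `max_seq_length` set to a small value and `101`/`102` standing in for BERT's `[CLS]`/`[SEP]` ids (the real script gets these from the tokenizer and CLI args; the values here are purely illustrative):

```python
from itertools import chain

max_seq_length = 4  # toy value; the real script takes this from its arguments

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder so every chunk has exactly max_seq_length tokens.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }

# Fake tokenized batch: 101 = [CLS], 102 = [SEP] in a BERT-style vocab.
batch = {"input_ids": [[101, 5, 6, 102], [101, 7, 8, 9, 10, 102], [101, 11, 12, 102]]}
chunks = group_texts(batch)["input_ids"]
# The third chunk starts mid-sentence, with no [CLS] at the front.
```

Running this, the last chunk begins with an ordinary token rather than `[CLS]`, which is exactly the behavior the question below is about.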
We call Dataset.map() on the tokenized dataset with the above group_texts function. So the variable concatenated_examples inside that function, when decoded, looks something like [CLS]sentence1[SEP][CLS]sentence2[SEP][CLS]sentence3[SEP]… and so on.
When this is then split into chunks of max_seq_length, some of the resulting inputs are not guaranteed to have [CLS] at the beginning or [SEP] at the end.
So, shouldn’t the above snippet be modified into something like the version below?
def group_texts(examples):
    max_num_tokens = max_seq_length - 2  # to account for [CLS] and [SEP]
    # Strip each example's own [CLS]/[SEP], then concatenate all texts.
    concatenated_examples = {k: list(chain(*[l[1:-1] for l in examples[k]])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_num_tokens:
        total_length = (total_length // max_num_tokens) * max_num_tokens
    # Split into chunks of max_num_tokens, re-adding the special tokens around each chunk.
    EMPTY = tokenizer("", return_special_tokens_mask=True)
    result = {
        k: [[EMPTY[k][0]] + t[i : i + max_num_tokens] + [EMPTY[k][1]]
            for i in range(0, total_length, max_num_tokens)]
        for k, t in concatenated_examples.items()
    }
    return result
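Here is the same idea as a self-contained pure-Python sketch that can be checked without a real tokenizer. The function name group_texts_with_specials is mine, and the per-key special-token pairs are hardcoded (101/102 for input_ids, 1/1 for attention_mask) as stand-ins for what tokenizer("", return_special_tokens_mask=True) would return for a BERT-style tokenizer:

```python
from itertools import chain

max_seq_length = 6  # toy value for illustration

# Stand-ins for tokenizer("")'s per-key values: (first, last) special token.
SPECIALS = {"input_ids": (101, 102), "attention_mask": (1, 1)}

def group_texts_with_specials(examples):
    max_num_tokens = max_seq_length - 2  # reserve room for [CLS] and [SEP]
    # Strip each example's own [CLS]/[SEP] before concatenating.
    concatenated = {k: list(chain(*[l[1:-1] for l in examples[k]])) for k in examples}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the small remainder, as in the original function.
    if total_length >= max_num_tokens:
        total_length = (total_length // max_num_tokens) * max_num_tokens
    # Chunk, then wrap each chunk in the appropriate special tokens.
    return {
        k: [[SPECIALS[k][0]] + t[i : i + max_num_tokens] + [SPECIALS[k][1]]
            for i in range(0, total_length, max_num_tokens)]
        for k, t in concatenated.items()
    }

batch = {
    "input_ids": [[101, 5, 6, 102], [101, 7, 8, 9, 10, 102]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
}
result = group_texts_with_specials(batch)
# Every input_ids chunk now starts with [CLS] and ends with [SEP].
```

With this variant every chunk has exactly max_seq_length tokens and well-formed `[CLS] … [SEP]` boundaries, at the cost of packing tokens from different original sentences between a single pair of special tokens.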