Query about group_texts in run_mlm_no_trainer.py

The following snippet is from transformers/run_mlm_no_trainer.py in huggingface/transformers (commit eb5bdcdfa51f743887ee1d9c7f230444d7a8b23c):

```python
from itertools import chain  # imported at the top of run_mlm_no_trainer.py

# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result
```

The script applies `group_texts` to the tokenized dataset via `Dataset.map()`. Because each text was tokenized separately, `concatenated_examples["input_ids"]` in the function above, once decoded, looks something like `[CLS] sentence1 [SEP] [CLS] sentence2 [SEP] [CLS] sentence3 [SEP] ...` and so on.

When this sequence is then split into chunks of `max_seq_length`, the resulting inputs are not guaranteed to begin with `[CLS]` or end with `[SEP]`, and a chunk can contain `[SEP] [CLS]` boundaries in its middle.
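To make the problem concrete, here is a minimal, dependency-free sketch that reproduces the script's chunking logic on toy data. The token IDs (101 for `[CLS]`, 102 for `[SEP]`, as in BERT-style vocabularies) and the three "sentences" are assumptions for illustration, and `max_seq_length` is passed as a parameter instead of being read from the script's arguments.

```python
from itertools import chain

CLS, SEP = 101, 102  # assumed BERT-style special-token IDs, for illustration only


def group_texts(examples, max_seq_length):
    """Same logic as the script's group_texts, with max_seq_length passed in."""
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that does not fill a whole chunk.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated.items()
    }


# Three toy "sentences", each tokenized separately as [CLS] ... [SEP].
batch = {
    "input_ids": [
        [CLS, 1, 2, 3, SEP],
        [CLS, 4, 5, SEP],
        [CLS, 6, 7, 8, 9, SEP],
    ]
}

chunks = group_texts(batch, max_seq_length=6)["input_ids"]
for chunk in chunks:
    print(chunk)
```

With `max_seq_length=6`, the first chunk is `[101, 1, 2, 3, 102, 101]` (it ends in `[CLS]`), the second is `[4, 5, 102, 101, 6, 7]` (it starts mid-sentence), and the trailing `[8, 9, 102]` is dropped as the remainder.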

So shouldn't the snippet above be modified to something like the following?

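```python
from itertools import chain

# max_seq_length and tokenizer are the ones already defined in run_mlm_no_trainer.py.
max_num_tokens = max_seq_length - 2  # to account for [CLS] and [SEP]
# Tokenizing "" returns only the special tokens, so EMPTY[k][0] / EMPTY[k][-1] give the
# right [CLS] / [SEP] value for every key (input_ids, attention_mask, special_tokens_mask, ...).
# Computed once here rather than on every call to group_texts.
EMPTY = tokenizer("", return_special_tokens_mask=True)


def group_texts(examples):
    # Strip each example's own [CLS] and [SEP], then concatenate all texts.
    concatenated_examples = {k: list(chain(*[seq[1:-1] for seq in examples[k]])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_num_tokens:
        total_length = (total_length // max_num_tokens) * max_num_tokens
    # Split into chunks of max_num_tokens and wrap each chunk in [CLS] ... [SEP].
    result = {
        k: [[EMPTY[k][0]] + t[i : i + max_num_tokens] + [EMPTY[k][-1]] for i in range(0, total_length, max_num_tokens)]
        for k, t in concatenated_examples.items()
    }
    return result
```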
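The proposed version can be checked the same way. This sketch assumes every tokenized example begins with exactly one `[CLS]` and ends with exactly one `[SEP]`, and it replaces the `tokenizer("", return_special_tokens_mask=True)` call with a hand-written `EMPTY` dict holding the values a BERT-style tokenizer is assumed to return for the empty string. The function name `group_texts_wrapped` and all token IDs are hypothetical.

```python
from itertools import chain

# Stand-in for tokenizer("", return_special_tokens_mask=True) on a BERT-style
# tokenizer: two entries per key, one for [CLS] and one for [SEP] (assumed values).
EMPTY = {
    "input_ids": [101, 102],
    "attention_mask": [1, 1],
    "special_tokens_mask": [1, 1],
}


def group_texts_wrapped(examples, max_seq_length, empty=EMPTY):
    max_num_tokens = max_seq_length - 2  # leave room for [CLS] and [SEP]
    # Drop each example's own [CLS]/[SEP] before concatenating.
    concatenated = {
        k: list(chain(*[seq[1:-1] for seq in examples[k]])) for k in examples.keys()
    }
    total_length = len(concatenated[list(examples.keys())[0]])
    if total_length >= max_num_tokens:
        total_length = (total_length // max_num_tokens) * max_num_tokens
    # Re-wrap every chunk with the key-appropriate first and last values.
    return {
        k: [
            [empty[k][0]] + t[i : i + max_num_tokens] + [empty[k][-1]]
            for i in range(0, total_length, max_num_tokens)
        ]
        for k, t in concatenated.items()
    }


batch = {
    "input_ids": [[101, 1, 2, 3, 102], [101, 4, 5, 102], [101, 6, 7, 8, 9, 102]],
    "attention_mask": [[1] * 5, [1] * 4, [1] * 6],
    "special_tokens_mask": [[1, 0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 0, 0, 1]],
}

out = group_texts_wrapped(batch, max_seq_length=6)
for ids, stm in zip(out["input_ids"], out["special_tokens_mask"]):
    print(ids, stm)
```

Every chunk now starts with `[CLS]` and ends with `[SEP]` (`[101, 1, 2, 3, 4, 102]` and `[101, 5, 6, 7, 8, 102]`), and `special_tokens_mask` marks exactly those two positions. One side effect is that the boundaries between the original sentences inside a chunk (e.g. between tokens 3 and 4) are no longer marked by any separator token.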