Create custom data_collator for Huggingface Trainer

I need to create a custom data_collator for fine-tuning with the Hugging Face Trainer API.

Hugging Face provides DataCollatorForWholeWordMask, which masks whole words in the input with a given probability:

    from transformers import AutoTokenizer, DataCollatorForWholeWordMask

    model_ckpt    = "vinai/bertweet-base"
    tokenizer     = AutoTokenizer.from_pretrained(model_ckpt, normalization=True)
    data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=args.mlm_prob)

I am passing the collator as a Trainer argument:

    from transformers import Trainer

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset['train'],
        eval_dataset=None,
        data_collator=data_collator,
    )
    trainer.train()

But in my use case, a sample input looks like <sent1>.<sent2>, and I want to mask tokens only in <sent2>, not in <sent1>. How can I go about this? Any pointers on getting started would also be welcome.

Subclass DataCollatorForWholeWordMask and override its torch_mask_tokens method. The parent class's torch_call first builds a 0/1 mask_labels tensor marking the candidate positions selected for whole-word masking and then hands it to torch_mask_tokens, so if you zero out the entries that belong to <sent1> before delegating to the parent, only tokens in <sent2> can ever be masked.
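
Here is a minimal sketch. It assumes the two sentences are joined by a single period and that the first '.' token in each sequence marks the end of <sent1>; the class name DataCollatorForSecondSentenceMask is hypothetical, and you should adapt the boundary detection to however your data actually separates the sentences:

    import torch
    from transformers import DataCollatorForWholeWordMask

    class DataCollatorForSecondSentenceMask(DataCollatorForWholeWordMask):
        """Whole-word masking restricted to tokens after the first '.' token.

        Hypothetical name. Assumes each input is '<sent1>.<sent2>' and that
        the first period token in the sequence marks the end of <sent1>.
        """

        def torch_mask_tokens(self, inputs, mask_labels):
            # Id the tokenizer assigns to '.' -- adjust if your separator differs.
            period_id = self.tokenizer.convert_tokens_to_ids(".")

            for i in range(inputs.size(0)):
                period_positions = (inputs[i] == period_id).nonzero(as_tuple=True)[0]
                if len(period_positions) > 0:
                    boundary = int(period_positions[0])
                    # Remove <sent1> (and the separator itself) from the
                    # candidate positions, so only <sent2> can be masked.
                    mask_labels[i, : boundary + 1] = 0

            # The parent implementation applies the usual 80/10/10
            # mask/random/keep scheme to the remaining candidates.
            return super().torch_mask_tokens(inputs, mask_labels)

You can then construct the collator exactly as before:

    data_collator = DataCollatorForSecondSentenceMask(tokenizer=tokenizer, mlm_probability=args.mlm_prob)

Nothing else in the Trainer setup needs to change, since the collator still returns the same {"input_ids": ..., "labels": ...} batch format as the parent class.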