How can I pretrain a T5 model?

Hi, I am trying to pretrain a T5 model from scratch but couldn’t find a proper implementation guide. I’m not sure which data collator I should use. I believe that DataCollatorForLanguageModeling is meant for BERT-style MLM, and DataCollatorForSeq2Seq expects you to already have input–output pairs.

I once found another implementation called DataCollatorForT5MLM in examples/flax/language-modeling/run_t5_mlm_flax.py in the Transformers repo, but I can’t find it anymore; it has probably been removed. However, based on my understanding, this seems to be the correct implementation. I am adding the code below; please let me know if this is the right implementation.

import numpy as np
import torch
from transformers import BatchEncoding

class DataCollatorForT5MLM:
    """
    Data collator used for T5 span-masked language modeling.
    It ensures that, after masking, the inputs have length `data_args.max_seq_length` and the targets also have a fixed length.
    For more information on how T5 span-masked language modeling works, one can take a look
    at the `official paper <https://huggingface.co/papers/1910.10683>`__
    or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        noise_density (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input.
        mean_noise_span_length (:obj:`float`):
            The average span length of the masked tokens.
        input_length (:obj:`int`):
            The expected input length after masking.
        target_length (:obj:`int`):
            The expected target length after masking.
        pad_token_id (:obj:`int`):
            The pad token id of the model.
        decoder_start_token_id (:obj:`int`):
            The decoder start token id of the model.
    """

    def __init__(self, tokenizer, noise_density, mean_noise_span_length, input_length, target_length, pad_token_id, decoder_start_token_id):
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.input_length = input_length
        self.target_length = target_length
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
        """
        Shift input ids one token to the right (used for preparing decoder inputs).
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)

        # Shift everything one step to the right
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()

        # Put decoder_start_token_id at position 0
        shifted_input_ids[:, 0] = decoder_start_token_id

        # Replace -100 values with pad_token_id
        shifted_input_ids = shifted_input_ids.masked_fill(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids

    def __call__(self, examples: list[dict[str, np.ndarray]]) -> BatchEncoding:
        # convert list to dict and tensorize input
        batch = BatchEncoding(
            {k: torch.tensor([examples[i][k] for i in range(len(examples))], dtype=torch.long) for k, v in examples[0].items()}
        )

        input_ids = batch["input_ids"]
        batch_size, expanded_input_length = input_ids.shape

        mask_indices = np.asarray([self.random_spans_noise_mask(expanded_input_length) for i in range(batch_size)])
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))

        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)

        if batch["input_ids"].shape[-1] != self.input_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
                f" should be {self.input_length}."
            )

        if batch["labels"].shape[-1] != self.target_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
                f" {self.target_length}."
            )

        # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
        batch["decoder_input_ids"] = self.shift_tokens_right(
            batch["labels"], self.pad_token_id, self.decoder_start_token_id
        )
        # print(f"Shape of batch['input_ids']: {batch['input_ids'].shape}")
        # print(f"Shape of batch['labels']: {batch['labels'].shape}")
        # print(f"Shape of batch['decoder_input_ids']: {batch['decoder_input_ids'].shape}")

        return batch

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
        sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
        )
        # return input_ids
        return torch.tensor(input_ids, dtype=torch.long)

    def random_spans_noise_mask(self, length):
        """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number

        Returns:
            a boolean tensor with shape [length]
        """

        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        num_nonnoise_tokens = length - num_noise_tokens
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        # num_noise_spans is at most min(num_noise_tokens, num_nonnoise_tokens), since every span needs at least one token
        num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)

        # pick the lengths of the noise spans and the non-noise spans

        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]
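
For context, here is a minimal usage sketch showing how this collator could be wired up. Treat it as illustrative rather than authoritative: the "t5-small" tokenizer, the 0.15 / 3.0 corruption settings, and the compute_input_and_target_lengths helper below are adapted from my reading of the same run_t5_mlm_flax.py script. The key point is that every example handed to the collator must already be expanded_inputs_length tokens long, so that after the masked spans are collapsed into sentinels the inputs come out at exactly input_length (otherwise the ValueError checks above will trigger).

import numpy as np
from transformers import AutoTokenizer

def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    # Find the raw (pre-masking) chunk length whose corrupted inputs are exactly `inputs_length` tokens.
    def _lengths(tokens_length):
        num_noise_tokens = int(np.round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(np.round(num_noise_tokens / mean_noise_span_length))
        # inputs = non-noise tokens + one sentinel per noise span + EOS
        # targets = noise tokens + one sentinel per noise span + EOS
        return num_nonnoise_tokens + num_noise_spans + 1, num_noise_tokens + num_noise_spans + 1

    tokens_length = inputs_length
    while _lengths(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    _, targets_length = _lengths(tokens_length)
    return tokens_length, targets_length

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # or your custom tokenizer

max_seq_length = 512
expanded_inputs_length, target_length = compute_input_and_target_lengths(
    inputs_length=max_seq_length, noise_density=0.15, mean_noise_span_length=3.0
)  # gives (568, 114) for these defaults

# Chunk/group your tokenized corpus into blocks of `expanded_inputs_length` tokens,
# then feed the collator those fixed-length examples.
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=max_seq_length,
    target_length=target_length,
    pad_token_id=tokenizer.pad_token_id,
    decoder_start_token_id=0,  # standard T5 configs use 0 (= pad); prefer model.config.decoder_start_token_id for your own model
)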

Trainer

from transformers import Trainer, TrainingArguments
from transformers.optimization import Adafactor, get_inverse_sqrt_schedule

training_args = TrainingArguments(
    output_dir="Saved_models/" + args.output_dir,
    per_device_train_batch_size=16,   # adjust based on GPU memory
    gradient_accumulation_steps=2,    # to simulate a larger batch size
    max_steps=args.max_step,
    logging_steps=20000,
    save_steps=20000,
    save_total_limit=1,
    dataloader_pin_memory=False,
    bf16=True,                        # mixed precision for faster training and lower memory usage
    ddp_find_unused_parameters=False,
    report_to="tensorboard",
    logging_strategy="steps",
    logging_dir="logs/new_log_test",
    load_best_model_at_end=True,
    eval_strategy="steps",
    metric_for_best_model="eval_loss",
    per_device_eval_batch_size=16,
    eval_steps=20000,
    warmup_steps=args.warm_up_step,
    eval_accumulation_steps=1,
    disable_tqdm=False,
    log_level="info",
)



# Optimizer
optimizer = Adafactor(
    model.parameters(),
    lr=args.lr,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

# Create the inverse square root scheduler
scheduler = get_inverse_sqrt_schedule(
    optimizer,
    num_warmup_steps=training_args.warmup_steps,
)

print("trainer init")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,  # the custom DataCollatorForT5MLM instance goes here
    optimizers=(optimizer, scheduler),
)

print("training started")

trainer.train()

My second question is: if this implementation is correct, should I use the same DataCollatorForT5MLM class for fine-tuning too, or should I use DataCollatorForSeq2Seq for the translation work?


this seems to be the correct implementation. I am adding the code below; please let me know if this is the right implementation.

That seems like the right implementation.

How can I pretrain a T5 model?

Maybe like this.


Thank you so much, that was really helpful beyond my expectations.

I have one more question. I have pretrained the T5 model on biological sequences. Now, for translation work, do I need to add prefixes? From my understanding, the custom tokenizer would not work with natural language prefixes, so if I fine-tune one model for A→B and another model for B→A, will I still need to add natural language prefixes? (Although it seems that adding prefixes might allow a single model to translate between multiple languages (A ↔ B), I am not sure how to add the prefixes.)


I’m glad it was helpful. :grinning_face:


Short answer: use prefixes, but they don’t have to be natural language. With a biology tokenizer, add task control tokens like <A2B> and <B2A> as special tokens and prefix every source sequence with the right one. This matches T5’s “task-as-text” design and the multilingual NMT practice of a target-language tag. It lets one model handle A↔B; if you train two one-direction models, the prefix is optional but still useful for clarity and future multi-tasking. (arXiv)

How to do it

  1. Add control tokens to your tokenizer and resize embeddings.
control_tokens = ["<A2B>", "<B2A>"]          # pick clear names; keep them short
tokenizer.add_special_tokens({"additional_special_tokens": control_tokens})
model.resize_token_embeddings(len(tokenizer))

Hugging Face supports adding new special tokens explicitly; do not add them as ordinary tokens. (Hugging Face)

  2. Build inputs with the control prefix.
def make_example(src_ids, direction):
    prefix = "<A2B>" if direction == "A2B" else "<B2A>"
    # convert_tokens_to_ids returns a single id for a single-token string
    prefix_id = tokenizer.convert_tokens_to_ids(prefix)
    # If you also want a delimiter, add another special token, e.g., "<SEP>"
    return [prefix_id] + src_ids

This is the same conditioning idea as T5’s “translate English to German:” string, just using your own tokens. (arXiv)

  3. Train a single model for A↔B (see the preprocessing sketch after this list):
  • Always prefix the source with the tag indicating the target (e.g., <B2A> before B→A examples). This mirrors Google’s multilingual NMT target-language token and mBART’s language-code tokens. (arXiv)
  4. If you keep two directional models:
  • You can skip prefixes, but adding the right tag still reduces ambiguity and keeps the door open to merge corpora later. The overhead is negligible since tags are single tokens. The precedent from multilingual NMT shows these tags improve robustness and enable zero-shot behavior when you expand tasks. (arXiv)
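
To make the fine-tuning side concrete (and to come back to the earlier question): once you have explicit input–output pairs, the stock DataCollatorForSeq2Seq is typically what you want; the span-corruption collator above is only for the pretraining objective. A minimal, hypothetical preprocessing sketch follows, where raw_dataset, model, the "src"/"tgt" column names, and the length limit are placeholders for your own setup.

from transformers import DataCollatorForSeq2Seq

# Sanity check: the control token should stay atomic rather than being split by the tokenizer.
assert tokenizer.tokenize("<A2B>") == ["<A2B>"]

def preprocess(batch, direction="A2B", max_len=512):
    # Prepend the control token naming the translation direction to every source sequence.
    prefix = "<A2B> " if direction == "A2B" else "<B2A> "
    model_inputs = tokenizer(
        [prefix + s for s in batch["src"]],
        text_target=batch["tgt"],   # tokenized into `labels`
        max_length=max_len,
        truncation=True,
    )
    return model_inputs

ft_train_dataset = raw_dataset.map(
    preprocess, batched=True, remove_columns=raw_dataset.column_names
)

# Pads inputs and labels per batch and, given the model, prepares decoder_input_ids from the labels.
ft_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)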

Practical tips

  • Make the control tokens short and unique. Add all planned tags up front to avoid repeated embedding resizing. HF’s tokenizer docs and “added tokens” API cover this. (Hugging Face)
  • Keep the same prefix convention across training, eval, and inference. T5 expects task conditioning in the input string. (arXiv)
  • Biological T5 precedents: ProtT5 and related protein LMs use T5 architectures on non-NL tokens; downstream tasks often rely on task-specific prompts or tags. The architecture does not require natural language. (Hugging Face)

Bottom line: add special token prefixes like <A2B>/<B2A> and use a single model for both directions. This aligns with T5’s task-prefix design and multilingual translation best practice. (arXiv)