Can I train PyTorch T5 on TPU with variable batch shape?

My goal is to group sequences of similar length together in a batch and pad them to match the longest one in that batch. Will this work on TPU, given that the docs say it does not support dynamic shapes? A rough sketch of what I mean is below.
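Roughly, the batching I have in mind (just a sketch; the class name and details are mine):

```python
import torch
from torch.utils.data import Sampler

class LengthGroupedBatchSampler(Sampler):
    """Sort examples by length, then cut consecutive chunks into batches so
    that sequences inside a batch have similar lengths and padding stays small."""

    def __init__(self, lengths, batch_size):
        self.batch_size = batch_size
        # dataset indices sorted by sequence length
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

    def __iter__(self):
        for start in range(0, len(self.sorted_indices), self.batch_size):
            yield self.sorted_indices[start:start + self.batch_size]

    def __len__(self):
        return (len(self.sorted_indices) + self.batch_size - 1) // self.batch_size
```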

Hi @marton-avrios,
I’ve done exactly this for T5, based on the following article:
https://towardsdatascience.com/divide-hugging-face-transformers-training-time-by-2-or-more-21bf7129db9e

Here’s the code:
```python

from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    # T5 uses 0 as the pad token id
    pad_token_id = 0

    # Pad source ids and masks to the longest sequence in this batch
    src_ids = pad_sequence([sample['source_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    src_mask = pad_sequence([sample['source_mask'] for sample in batch], batch_first=True, padding_value=0)
    src_text = [sample['source_text'] for sample in batch]

    # Pad target ids, then replace pad tokens with -100 so they are
    # ignored by the cross-entropy loss
    tgt_ids = pad_sequence([sample['target_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    tgt_ids[tgt_ids == pad_token_id] = -100
    tgt_mask = pad_sequence([sample['target_mask'] for sample in batch], batch_first=True, padding_value=0)
    tgt_text = [sample['target_text'] for sample in batch]

    return {
        'source_ids': src_ids,
        'source_mask': src_mask,
        'target_ids': tgt_ids,
        'target_mask': tgt_mask,
        'source_text': src_text,
        'target_text': tgt_text,
    }

```
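You then hand this to the DataLoader as `collate_fn`; roughly like this (the dataset name and batch size are placeholders, and the dataset is assumed to return dicts with the keys used above):

```python
from torch.utils.data import DataLoader

# `train_dataset` must yield dicts with 'source_ids', 'source_mask',
# 'target_ids', 'target_mask', 'source_text' and 'target_text' keys.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)
```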

Hmmm… If I understand it correctly, this produces batches of variable shape, which is what I want. But my concern is that TPUs do not support variable tensor shapes: every shape has to be known at compile time and must not depend on the input data. One workaround I can think of is sketched below.
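Just a sketch of that idea (not verified on TPU; the bucket sizes are arbitrary): instead of padding each batch to its exact longest sequence, pad it up to the next length in a small fixed set of bucket sizes, so the compiler only ever sees a handful of distinct shapes.

```python
import torch
import torch.nn.functional as F

# Hypothetical bucket lengths; sequences longer than the last bucket are not handled here.
BUCKET_SIZES = [32, 64, 128, 256, 512]

def pad_to_bucket(ids, mask, pad_token_id=0):
    """Pad (batch, seq_len) tensors on the right up to the next bucket length."""
    seq_len = ids.size(1)
    target_len = next(b for b in BUCKET_SIZES if b >= seq_len)
    extra = target_len - seq_len
    ids = F.pad(ids, (0, extra), value=pad_token_id)
    mask = F.pad(mask, (0, extra), value=0)
    return ids, mask
```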