Question about dataset from TFRecord files

Hi everyone,

I have a question regarding this method found in the notebook about training an LLM on a TPU.

def prepare_dataset(
    records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None
):
    num_samples = count_samples(records)
    # First build a dataset over the record *filenames* only.
    dataset = tf.data.Dataset.from_tensor_slices(records)
    if shuffle:
        dataset = dataset.shuffle(len(dataset))
    # Then read the actual records from the (shuffled) files.
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)
    # TF can't infer the total sample count because it hasn't read
    # all the records yet, so we assert it here.
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
    dataset = dataset.map(decode_fn, num_parallel_calls=auto)
    if shuffle:
        assert shuffle_buffer_size is not None
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(mask_fn, num_parallel_calls=auto)
    dataset = dataset.prefetch(auto)
    return dataset

Since we get the dataset as TFRecord files from the GCS bucket, I do not understand why the code first wraps the record filenames in a plain tensor dataset (tf.data.Dataset.from_tensor_slices(records)) and shuffles that, and only then loads the actual data as a TFRecord dataset (tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)), with a second shuffle applied afterwards. Why is the initial from_tensor_slices step needed here?
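To make sure I am reading it right, here is a minimal, self-contained sketch of the two-step pattern I am asking about (the toy shard files and sample strings are made up just for illustration, they are not from the notebook):

```python
import os
import tempfile
import tensorflow as tf

# Write two tiny toy TFRecord shards with 3 raw string records each.
tmpdir = tempfile.mkdtemp()
records = []
for i in range(2):
    path = os.path.join(tmpdir, f"shard-{i}.tfrecord")
    with tf.io.TFRecordWriter(path) as writer:
        for j in range(3):
            writer.write(f"sample-{i}-{j}".encode())
    records.append(path)

# Step 1: a plain dataset over the *filenames* only, shuffled at file level.
filenames = tf.data.Dataset.from_tensor_slices(records)
filenames = filenames.shuffle(len(records))

# Step 2: read the actual records from the (now shuffled) files.
dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)

# Counting elements shows all records are read, whatever the file order.
n = sum(1 for _ in dataset)
print(n)  # 6
```

So if I understand correctly, the first shuffle only reorders the files themselves, before any record is read.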


Best regards