Question about building a dataset from TFRecord files

Hi everyone,

I have a question about this function from the notebook on training an LLM on a TPU.

def prepare_dataset(
    records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None
):
    # `auto` (tf.data.AUTOTUNE) and `count_samples` are defined earlier
    # in the notebook.
    num_samples = count_samples(records)
    # Build a dataset whose elements are the TFRecord file paths.
    dataset = tf.data.Dataset.from_tensor_slices(records)
    if shuffle:
        dataset = dataset.shuffle(len(dataset))  # shuffle the file order
    # Read the (shuffled) files, yielding serialized records.
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)
    # TF can't infer the total sample count because it hasn't read
    # all the records yet, so we assert it here.
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
    dataset = dataset.map(decode_fn, num_parallel_calls=auto)
    if shuffle:
        assert shuffle_buffer_size is not None
        dataset = dataset.shuffle(shuffle_buffer_size)  # shuffle individual samples
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(mask_fn, num_parallel_calls=auto)
    dataset = dataset.prefetch(auto)
    return dataset
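
For reference, a call to it might look like the sketch below; the glob pattern, feature spec, and decode/mask helpers are placeholders I made up, not the notebook's actual ones:

import tensorflow as tf

def my_decode_fn(serialized):
    # Placeholder: parse one serialized tf.train.Example.
    features = {"input_ids": tf.io.FixedLenFeature([128], tf.int64)}
    return tf.io.parse_single_example(serialized, features)

def my_mask_fn(batch):
    # Placeholder: the real function applies masking to each batch.
    return batch

record_files = tf.io.gfile.glob("gs://my-bucket/tfrecords/*.tfrecord")
train_ds = prepare_dataset(
    record_files,
    decode_fn=my_decode_fn,
    mask_fn=my_mask_fn,
    batch_size=64,
    shuffle=True,
    shuffle_buffer_size=10_000,
)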

Since we already get the dataset as TFRecord files from the GCS bucket, I do not understand why the code first builds a classical tensor dataset over the file paths (tf.data.Dataset.from_tensor_slices(records)), shuffles it, and only then reads the shuffled files as a TFRecord dataset (tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)) before shuffling again at the sample level. Why is the initial from_tensor_slices step needed here?
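
To make sure I am pointing at the right thing, here is a minimal sketch of the two stages as I read them (the file names and buffer size are placeholders):

import tensorflow as tf

# Stage 1: a dataset whose elements are the file *paths*, so shuffling here
# reorders whole files, not individual samples.
records = ["gs://my-bucket/shard-00.tfrecord", "gs://my-bucket/shard-01.tfrecord"]
file_ds = tf.data.Dataset.from_tensor_slices(records)
file_ds = file_ds.shuffle(len(records))

# Stage 2: read the shuffled files into serialized samples, then shuffle
# again at the sample level with a bounded buffer.
sample_ds = tf.data.TFRecordDataset(file_ds, num_parallel_reads=tf.data.AUTOTUNE)
sample_ds = sample_ds.shuffle(1000)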

Thanks

Best regards

Jérôme