Hi everyone,
I have a question regarding this method from the notebook about training an LLM on a TPU:
import tensorflow as tf

# `auto` and `count_samples` are defined earlier in the notebook;
# `auto` is tf.data.AUTOTUNE.
auto = tf.data.AUTOTUNE


def prepare_dataset(
    records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None
):
    num_samples = count_samples(records)
    # A dataset whose elements are the TFRecord file paths themselves.
    dataset = tf.data.Dataset.from_tensor_slices(records)
    if shuffle:
        dataset = dataset.shuffle(len(dataset))
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)
    # TF can't infer the total sample count because it hasn't read
    # all the records yet, so we assert it here.
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
    dataset = dataset.map(decode_fn, num_parallel_calls=auto)
    if shuffle:
        assert shuffle_buffer_size is not None
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(mask_fn, num_parallel_calls=auto)
    dataset = dataset.prefetch(auto)
    return dataset
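For context: as far as I can tell, records here is just a plain Python list of TFRecord file paths in the GCS bucket, built with something like the following (the bucket path is a placeholder, not the notebook's actual path):

    import tensorflow as tf

    # Hypothetical reconstruction of how `records` is obtained:
    # a list of TFRecord file paths on GCS.
    records = tf.io.gfile.glob("gs://my-bucket/train/*.tfrecord")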
Since we get the dataset as TFRecord files from the GCS bucket, I do not understand why the code first loads the file list as a plain tensor dataset (tf.data.Dataset.from_tensor_slices(records)), shuffles it, and only then loads the shuffled files as a TFRecord dataset (tf.data.TFRecordDataset(dataset, num_parallel_reads=auto)) before shuffling again later. Why is the initial tf.data.Dataset.from_tensor_slices step needed here?
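To make my mental model explicit: if I read the code correctly, the first dataset contains file paths and the second one contains the serialized examples, roughly like this sketch (paths are placeholders):

    import tensorflow as tf

    # Stage 1: a dataset whose elements are the *file paths*, shuffled so the
    # shards are read in a different order each time.
    filenames = ["gs://my-bucket/shard-0.tfrecord", "gs://my-bucket/shard-1.tfrecord"]
    file_ds = tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))

    # Stage 2: TFRecordDataset accepts a dataset of filenames and yields the
    # serialized examples inside each shard; the later shuffle(shuffle_buffer_size)
    # then mixes individual samples in memory.
    record_ds = tf.data.TFRecordDataset(file_ds, num_parallel_reads=tf.data.AUTOTUNE)

So is the first shuffle only there to randomize the order of the shard files, while the second shuffle mixes the decoded samples themselves?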
Thanks
Best regards
Jérôme