How do I specify a weight column when using to_tf_dataset()?

Hi, I have a Hugging Face dataset that I would like to convert to a tf.data.Dataset using to_tf_dataset.

Keras's model.fit() can consume data as tuples of the form (inputs, targets, sample_weights). However, I don't see an option in to_tf_dataset for choosing which column holds the weights.

How do I get the tf.data.Dataset to yield elements in the format (inputs, targets, sample_weights)?

cc @Rocketknight1

Hi @9las, this isn't natively supported, but you can make it work pretty easily! The key idea is that the output of to_tf_dataset() is an ordinary tf.data.Dataset, so you can apply all the standard tf.data techniques and transformations to it.

For example, one way to handle this would be:

  1. Add a sample_weight column to your dataset.
  2. Apply to_tf_dataset() to convert your dataset to a tf.data.Dataset, and make sure you retain that column by including sample_weight in the columns argument.
  3. Add a transformation at the end using dataset.map() that extracts the sample_weight key from the inputs dict and returns it as the third element of each tuple. For example, something like this:
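```python
def extract_sample_weight(inputs, labels):
    # Remove the weight from the input dict so the model doesn't receive it as a feature
    weight = inputs.pop("sample_weight")
    # Keras treats the third element of each tuple as the sample weight
    return inputs, labels, weight

dataset = dataset.map(extract_sample_weight)
```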

As long as your dataset outputs tuples with three elements, Keras will recognize the third one as the sample weight.
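To check that model.fit() really picks up the third element, here is a minimal sketch that skips the datasets step and builds an (inputs, targets, sample_weight) dataset directly with tf.data.Dataset.from_tensor_slices. The random data, the weights and the one-layer model are all hypothetical, chosen only so the example runs.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((8, 3)).astype("float32")
y = rng.integers(0, 2, size=(8, 1)).astype("float32")
# Hypothetical weights: the last four examples count five times as much in the loss
w = np.array([1, 1, 1, 1, 5, 5, 5, 5], dtype="float32")

# Each element is an (inputs, targets, sample_weight) tuple
dataset = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(4)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# No sample_weight argument is passed: Keras reads it from the third tuple element
history = model.fit(dataset, epochs=1, verbose=0)
print(history.history["loss"])
```

This is also why the weights have to travel inside the dataset: model.fit() does not support a separate sample_weight argument when x is a dataset.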