Collate function for tabular data with some text

So I have two levels of ragged tensors. Each example in the batch is a table of 1 to 20ish rows. Some of the features are numeric and others are text. So I need to pad the rows to the highest number of rows in the batch, and I need to tokenize and then pad the text features as well. And I need to produce masks, of course.

And I’m making it into a TensorFlow Dataset.

Anyone who can point me to a good starting point code-wise?

Edit: Actually I have some lengths for the text strings so I could just pad them all to the same length if the dynamic thing is too tricky.


If you look at the .to_tf_dataset method in huggingface/datasets on GitHub (at 2.8.0), you’ll find the np_get_batch function, which tells you the format in which the collate function expects to see the batch. Spoiler: it’s a list of dictionaries where the column names are keys and the values represent a single row in the dataset. Each row in the dataset is a table for me, so I “just” need to pad out the lines in the tables, create masks, tokenize and pad/truncate the text, create masks, join the masks from text and lines, and return the batch.
And I need to figure out which format .to_tf_dataset expects the collate function to return. Does anyone know?
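For anyone following along, the steps above could be sketched roughly like this. It is only a minimal NumPy sketch under assumptions, not what .to_tf_dataset requires: the column names "amount" and "note" are made up, toy_tokenize is a hypothetical stand-in for a real tokenizer, the text is padded or truncated to a fixed max_tokens rather than dynamically, and the returned dictionary layout is my own guess at a convenient format.

```python
import numpy as np

PAD_ID = 0  # assumed padding id; real tokenizers define their own


def toy_tokenize(text, vocab):
    """Hypothetical stand-in for a real tokenizer: whitespace split, ids start at 1."""
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.lower().split()]


def collate(batch, vocab, max_tokens=8):
    """Pad a list of per-example dicts (one list entry per table row) into dense arrays."""
    n_examples = len(batch)
    n_rows = max(len(ex["amount"]) for ex in batch)  # longest table in this batch

    amounts = np.zeros((n_examples, n_rows), dtype=np.float32)
    token_ids = np.full((n_examples, n_rows, max_tokens), PAD_ID, dtype=np.int64)
    row_mask = np.zeros((n_examples, n_rows), dtype=bool)
    token_mask = np.zeros((n_examples, n_rows, max_tokens), dtype=bool)

    for i, ex in enumerate(batch):
        k = len(ex["amount"])
        amounts[i, :k] = ex["amount"]  # remaining rows stay zero-padded
        row_mask[i, :k] = True
        for j, text in enumerate(ex["note"]):
            ids = toy_tokenize(text, vocab)[:max_tokens]  # truncate long text
            token_ids[i, j, : len(ids)] = ids
            token_mask[i, j, : len(ids)] = True

    # Join the masks: a token only counts as real if its row is real too.
    token_mask &= row_mask[:, :, None]
    return {
        "amount": amounts,
        "note_ids": token_ids,
        "row_mask": row_mask,
        "token_mask": token_mask,
    }


batch = [
    {"amount": [1.0, 2.0], "note": ["paid in full", "late"]},
    {"amount": [3.0], "note": ["refund issued to customer"]},
]
out = collate(batch, vocab={}, max_tokens=3)
```

Padding to the longest text in the batch instead of a fixed max_tokens would only change how the last dimension is sized; the masks work the same way.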

PS There’s a weird bug where it says

if cols_to_retain is not None:
elif cols_to_retain is not None:

Obviously, the elif will never run since it has the same condition as the preceding if. But I have no idea whether it affects the usability of the code.

PPS There’s a bad thing also (as far as I can tell). I can’t set the size of the shuffle buffer, it is always the length of the dataset. So if the data doesn’t fit in memory then I can’t shuffle before batching the data. And shuffling after batching is suboptimal. And unbatching, shuffling, batching is pretty slow (or at least it used to be). Ideally shuffling would happen as some kind of random sampling without replacement from the arrow table but I’m not sure that’s possible.

PPPS I notice now that HF expects everything to be integers so I can’t use it. Specifically, tf_dataset =, dtype=np.int64)).


Hi @grofte! Firstly, sorry for the delay, and thanks for pointing out the redundant code in to_tf_dataset. I’ll make a PR to clean that up soon! You’re actually one of the very few users who’s delved into my code there, so congratulations!

It’s important not to treat to_tf_dataset as a general solution to creating a TensorFlow dataset, though. That method is designed to wrap HuggingFace Dataset objects specifically, and the code isn’t always able to follow TF best practices. In particular, the use of methods like np_get_batch is a workaround to handle the fact that TensorFlow has no idea how to read our datasets (yet!), so we have to use NumPy/Python code to actually load samples from the dataset. I wouldn’t copy that code for your own data if you already have the data in TF-friendly formats like tf.Tensor or tf.RaggedTensor.

Also, the code tf_dataset =, dtype=np.int64)) is not actually shuffling the dataset itself - it’s shuffling an index array that spans the whole dataset, and then sampling batches of indices and passing them to np_get_batch using tf.numpy_function, which loads those indices from the dataset. As a result, to_tf_dataset doesn’t need to load your whole dataset into memory, and it doesn’t even try to! It only needs to fit the index array into memory, which is going to be tiny for all but the most gigantic datasets.
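To illustrate the idea of shuffling indices rather than data, here is a minimal sketch in plain NumPy. It is not the actual to_tf_dataset code, which goes through tf.data and tf.numpy_function as described above; index_batches and fetch_rows are hypothetical names, and the Python list stands in for a dataset that would normally live on disk.

```python
import numpy as np


def index_batches(num_rows, batch_size, seed=0):
    """Yield batches of shuffled row indices; only the index array is held in memory."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_rows)  # shuffle the indices, not the data
    for start in range(0, num_rows, batch_size):
        yield indices[start : start + batch_size]


def fetch_rows(table, batch_indices):
    """Hypothetical stand-in for a fetch function such as np_get_batch."""
    return [table[int(i)] for i in batch_indices]


# A list stands in for the on-disk dataset here.
table = [f"row-{i}" for i in range(10)]
batches = [fetch_rows(table, idx) for idx in index_batches(len(table), batch_size=4)]
```

Each row is read exactly once per pass, in random order, so this behaves like sampling without replacement while only ever loading one batch of rows at a time.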


Hi Matthew @Rocketknight1

No worries about delays at all. I ended up using TFio (which I didn’t even know existed). It lets me use Parquet files which again lets me get around how incredibly slow it was to write TFRecord files (of course I should have set it up so incremental data just adds additional files but I hadn’t done that - also wouldn’t work since data gets edited in the source).

I was able to do some group_by_window to get my weirdly shaped data into shape. And I am able to train models as long as I don’t get fancy. Resounding success.

Ooooohhhh. I somehow just read tf_dataset.shuffle(len(dataset)) without noticing that you were doing it on a list of indexes and feeding it to a fetch_function. So clever! I wonder if I could do that in tfio. Probably not.