Problems while filtering large datasets using `map`

As part of our flax-jax-community-week project, I encountered a problem when attempting to filter a large dataset to select only the items matching a certain condition. I’m processing the items with `map` in batched mode, but some of the batches contain no matching items. I need to return an empty result in that case, but I’m having trouble doing so. If I return an empty dictionary `{}`, I get schema errors. I think they happen when empty batches are followed by non-empty ones. If I return a dictionary whose keys map to empty lists, `{'column_name': []}`, I get an index out of bounds error.
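For reference, here is a minimal sketch of the second approach described above: a batch function that always returns every column, with empty lists when nothing matches. The helper name `make_batch_filter`, the column names `text` and `label`, and the length condition are all hypothetical and only serve to illustrate the shape of the data; this is plain Python and not necessarily the code from the project.

```python
from typing import Any, Callable, Dict, List

# A batch in batched mode: column name -> list of values, one per row.
Batch = Dict[str, List[Any]]


def make_batch_filter(
    predicate: Callable[[Dict[str, Any]], bool],
) -> Callable[[Batch], Batch]:
    """Build a batch function that keeps only rows satisfying `predicate`.

    Hypothetical helper for illustration; it is not part of any library.
    """

    def filter_batch(batch: Batch) -> Batch:
        columns = list(batch.keys())
        num_rows = len(batch[columns[0]]) if columns else 0
        # Indices of the rows that match the condition.
        keep = [
            i
            for i in range(num_rows)
            if predicate({c: batch[c][i] for c in columns})
        ]
        # Every column is always returned, even when `keep` is empty,
        # so the keys stay the same from one batch to the next.
        return {c: [batch[c][i] for i in keep] for c in columns}

    return filter_batch


# Hypothetical condition: keep rows whose "text" is longer than 5 characters.
keep_long = make_batch_filter(lambda example: len(example["text"]) > 5)

mixed = keep_long({"text": ["short", "long enough"], "label": [0, 1]})
empty = keep_long({"text": ["a", "b"], "label": [0, 1]})
print(mixed)  # {'text': ['long enough'], 'label': [1]}
print(empty)  # {'text': [], 'label': []}
```

A function like this would be passed as `dataset.map(keep_long, batched=True)`. Whether an all-empty batch such as `empty` is accepted may depend on the installed version of the library; in the setup described above it is exactly the case that raises the index out of bounds error. `Dataset.filter`, which takes a per-example predicate, may be a simpler option when the rows themselves do not need to be transformed.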

I filed this issue describing the problem in slightly more detail. If this is not a bug, or someone else has found a workaround, that’d be great to know.


Thanks for reporting! Hopefully we can have a fix soon 🙂

I just took a first go at it. Let me know if I’m off track!