Mapping one multi-element column of a dataset to a multi-row dataset with one element per row, duplicating the other features

Apologies for the spam.

I am currently trying to process an image dataset into a new representation.
Essentially, I am converting each image into a set of sub-images, and I would like each sub-image to be a single example in the dataset.
Currently I have a dataset of the form:

Dataset({
    features: ['coordinates', 'filename', 'img', 'label', 'full_label'],
    num_rows: 11969
})

and I have a map function which converts it to:

Dataset({
    features: ['coordinates', 'filename', 'sub_images', 'label', 'full_label'],
    num_rows: 11969
})

where "sub_images" is a list containing n sub-images.

I would like to convert this new dataset to the form:

Dataset({
    features: ['coordinates', 'filename', 'img', 'label', 'full_label'],
    num_rows: 11969*n
})

where each "sub_images" field "unrolls" into n separate rows, duplicating the corresponding coordinates, filename, label and full_label. I have attempted this with batched mapping, using the following function:

def patches_to_examples(example):
    return{"label": [example["label"] for _ in example["sub_images"]], 
           "full_label": [example["full_label"] for _ in example["sub_images"]], 
           "filename": [example["filename"] for _ in example["sub_images"]], 
           "coordinates": [example["coordinates"] for _ in example["sub_images"]], 
           "img":[np.array(image) for image in example["sub_images"]]}

ds = ds.map(patches_to_examples, batched = True, remove_columns = ds.column_names)

However, this only creates one row per example and stacks the images in a list, whereas I would like it to create len(sub_images) rows per example, with one image per row.
Any suggestions on where I'm going wrong?
Cheers in advance!

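Hi! With batched=True the function receives a batch of examples, so every column (batch["label"], batch["sub_images"], ...) is a list with one entry per example. Your function loops over the examples only, which is why you get one row per example with the images stacked. You need two for loops: one over the examples, and one over the sub-images of each example: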

import numpy as np

def patches_to_examples(batch):
    # Each column of the batch is a list with one entry per example, so repeat
    # every example's value once per sub-image of that same example.
    return {
        "label": [label for i, label in enumerate(batch["label"]) for _ in batch["sub_images"][i]],
        "full_label": [full_label for i, full_label in enumerate(batch["full_label"]) for _ in batch["sub_images"][i]],
        "filename": [filename for i, filename in enumerate(batch["filename"]) for _ in batch["sub_images"][i]],
        "coordinates": [coordinates for i, coordinates in enumerate(batch["coordinates"]) for _ in batch["sub_images"][i]],
        # Flatten the per-example lists of sub-images into one image per row.
        "img": [np.array(sub_image) for sub_images in batch["sub_images"] for sub_image in sub_images],
    }
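
To see the two loops in isolation, here is a minimal sketch on a toy batch, using a plain dict of lists as a stand-in for what a batched map function receives. The helper name unroll_batch, its parameters, and the toy values are made up for illustration; strings stand in for the sub-images.

```python
def unroll_batch(batch, list_column="sub_images", new_column="img"):
    """Give each element of `list_column` its own row, repeating the other columns."""
    # Number of sub-images held by each example in the batch.
    counts = [len(items) for items in batch[list_column]]
    # Outer loop over examples, inner loop repeats the value once per sub-image.
    out = {
        name: [value for value, n in zip(values, counts) for _ in range(n)]
        for name, values in batch.items()
        if name != list_column
    }
    # Flatten the nested lists so that every row holds exactly one sub-image.
    out[new_column] = [item for items in batch[list_column] for item in items]
    return out


# Toy batch of two examples (hypothetical values).
toy_batch = {
    "filename": ["a.png", "b.png"],
    "label": [0, 1],
    "sub_images": [["a0", "a1", "a2"], ["b0"]],
}

flat = unroll_batch(toy_batch)
print(flat["filename"])  # ['a.png', 'a.png', 'a.png', 'b.png']
print(flat["label"])     # [0, 0, 0, 1]
print(flat["img"])       # ['a0', 'a1', 'a2', 'b0']
```

On the real dataset, a function like this would presumably be passed to ds.map with batched=True and remove_columns=ds.column_names, as in the call from the question; the images may still need a conversion such as np.array.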

Alternatively, you can try using the pandas explode function:

ds.with_format("pandas").map(lambda df: df.explode("sub_images"), batched=True)

(You might need df.explode("sub_images").dropna() instead; otherwise examples with an empty sub_images list would produce NaNs.)
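
As a rough illustration of the explode behaviour outside of datasets, here is a small pandas-only sketch. The DataFrame contents are made up, and strings stand in for the sub-images. It restricts dropna to the exploded column (an assumption about what is wanted) so that rows with missing values in other columns are not dropped as well.

```python
import pandas as pd

# Toy frame: three examples, one of which has no sub-images (hypothetical values).
df = pd.DataFrame({
    "filename": ["a.png", "b.png", "c.png"],
    "sub_images": [["a0", "a1"], [], ["c0"]],
})

# explode gives every list element its own row and repeats the other columns;
# an empty list becomes a single row holding NaN.
exploded = df.explode("sub_images")

# Drop only the rows where the exploded column is missing, then renumber the rows,
# since explode keeps the original (now repeated) index values.
cleaned = exploded.dropna(subset=["sub_images"]).reset_index(drop=True)

print(exploded)
print(cleaned)
```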

Thanks very much!
Problem solved