Hi! A DatasetDict is simply a Python dictionary of datasets:
from datasets import DatasetDict, concatenate_datasets

# axis=1 joins the datasets column-wise, so each row ends up with imageA and imageB
train_dataset = concatenate_datasets([
    train_examples_A.rename_column("image", "imageA").remove_columns(["label"]),
    train_examples_B.rename_column("image", "imageB").remove_columns(["label"]),
], axis=1)
test_dataset = concatenate_datasets([
    test_examples_A.rename_column("image", "imageA").remove_columns(["label"]),
    test_examples_B.rename_column("image", "imageB").remove_columns(["label"]),
], axis=1)

# Group the two splits under their usual names
datasets = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
})
Make sure that train_examples_A and train_examples_B have the same number of rows before concatenating them (and likewise for the test sets): with axis=1 the columns are joined row by row, so the lengths must match.
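If the lengths differ, one possible way to align them is to trim both datasets to the shorter length. This is only a sketch: truncate_to_common_length is a hypothetical helper name, not part of the datasets library, and it assumes that dropping the extra rows is acceptable and that the rows of A and B are meant to be paired by position.

```python
def truncate_to_common_length(ds_a, ds_b):
    """Trim both datasets to the length of the shorter one (hypothetical helper)."""
    n = min(len(ds_a), len(ds_b))
    # Dataset.select keeps only the rows at the given indices
    return ds_a.select(range(n)), ds_b.select(range(n))
```

You would call it before the concatenation, e.g. train_examples_A, train_examples_B = truncate_to_common_length(train_examples_A, train_examples_B).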