Seeing AttributeError: 'Dataset' object has no attribute 'reshape' when using "dataset.get_nearest_examples"

Hello, I am trying to use Faiss indexing for an image similarity application, but I am getting an AttributeError. Here's the code I am using:

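import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, AutoModel

# Loading dataset
dataset = load_dataset("imagefolder", data_dir="/data")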

# Pre processing images
def transforms(examples):
    examples["image"] = [image.resize((384, 384)).convert("RGB") for image in examples["image"]]
    return examples

# Applying the transform function to dataset
dataset = dataset.map(transforms, batched=True, batch_size=48)

# seed and num_samples are defined elsewhere (not shown)
candidate_subset = dataset["train"].shuffle(seed=seed).select(range(num_samples))

test_dir = '/data/test'
train_dir = '/data/train'
model_ckpt = '/model'
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
hidden_dim = model.config.hidden_size
device = "cuda" if torch.cuda.is_available() else "cpu"

# Defining the embedding extractor using the model
# (transformation_chain, the per-image preprocessing transform, is defined elsewhere and not shown)
def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device
    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(image) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # use the hidden state at the first token position as the image embedding
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp

# Applying the embedding extractor to the train subset (batch_size is defined elsewhere)
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = candidate_subset.map(extract_fn, batched=True, batch_size=batch_size)

Up to this point, the code extracts the embeddings for the candidate images.

Next, I compute the embedding of a test image and add a Faiss index:

random_query_image = dataset["test"].shuffle(seed=seed).select(range(1))
random_query_image_emb = random_query_image.map(extract_fn, batched=True, batch_size=batch_size)

candidate_subset_emb.add_faiss_index(column='embeddings')

So far, everything works.

However, the next line of code raises an error:

scores, retrieved_examples = candidate_subset_emb.get_nearest_examples(
    "embeddings", random_query_image_emb, k=5
)

Here’s the traceback:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[13], line 2
      1 candidate_subset_emb.add_faiss_index(column='embeddings')
----> 2 scores, retrieved_examples = candidate_subset_emb.get_nearest_examples(
      3     "embeddings", random_query_image_emb, k=5
      4 )

File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:742, in IndexableMixin.get_nearest_examples(self, index_name, query, k, **kwargs)
    727 """Find the nearest examples in the dataset to the query.
    728 
    729 Args:
   (...)
    739     - examples (`dict`): The retrieved examples.
    740 """
    741 self._check_index_is_initialized(index_name)
--> 742 scores, indices = self.search(index_name, query, k, **kwargs)
    743 top_indices = [i for i in indices if i >= 0]
    744 return NearestExamplesResults(scores[: len(top_indices)], self[top_indices])

File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:702, in IndexableMixin.search(self, index_name, query, k, **kwargs)
    687 """Find the nearest examples indices in the dataset to the query.
    688 
    689 Args:
   (...)
    699     - indices (`List[List[int]]`): The indices of the retrieved examples.
    700 """
    701 self._check_index_is_initialized(index_name)
--> 702 return self._indexes[index_name].search(query, k, **kwargs)

File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:356, in FaissIndex.search(self, query, k, **kwargs)
    353 if len(query.shape) != 1 and (len(query.shape) != 2 or query.shape[0] != 1):
    354     raise ValueError("Shape of query is incorrect, it has to be either a 1D array or 2D (1, N)")
--> 356 queries = query.reshape(1, -1)
    357 if not queries.flags.c_contiguous:
    358     queries = np.asarray(queries, order="C")

AttributeError: 'Dataset' object has no attribute 'reshape'

I tried upgrading the datasets package, but that didn't help. Any help would be appreciated. Thanks.

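The query (the second argument to get_nearest_examples, random_query_image_emb in your code) must be a NumPy array, not a Dataset. Take the embedding vector out of the dataset and convert it to a NumPy array before calling get_nearest_examples.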
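A minimal sketch of that conversion, using only NumPy. It assumes the "embeddings" column comes back as a list of vectors (the default Python format), and the helper name dataset_row_to_query is hypothetical, not part of datasets. A stand-in list replaces random_query_image_emb["embeddings"] so the snippet runs on its own.

```python
import numpy as np

def dataset_row_to_query(embeddings_column, row=0):
    """Hypothetical helper: turn one row of an 'embeddings' column
    (a list of vectors) into a 1-D float32 NumPy array."""
    query = np.asarray(embeddings_column[row], dtype=np.float32)
    # FaissIndex.search accepts a 1-D array or a 2-D array of shape (1, N)
    if query.ndim != 1:
        raise ValueError(f"expected a 1-D embedding, got shape {query.shape}")
    return query

# Stand-in for random_query_image_emb["embeddings"]: one 4-dim embedding
fake_column = [[0.1, 0.2, 0.3, 0.4]]
query = dataset_row_to_query(fake_column)
print(query.shape, query.dtype)  # (4,) float32

# With the real objects, the call would look like:
# query = dataset_row_to_query(random_query_image_emb["embeddings"])
# scores, retrieved_examples = candidate_subset_emb.get_nearest_examples("embeddings", query, k=5)
```

Casting to float32 matches the dtype the Faiss index is built with by default.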

That worked, thank you.
Could you please tell me which distance function Dataset.get_nearest_examples() uses?

Can I choose the distance function, for example if I want to use cosine similarity?

If you could point me to the relevant documentation for the package, that'd be great. Thanks.

The default index (faiss.IndexFlat) uses Euclidean (L2) distance as its metric. Faiss reports squared L2 distances, so lower scores mean closer matches.
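To make the default behaviour concrete, here is a brute-force NumPy sketch of what a flat L2 index computes: the squared Euclidean distance from the query to every stored vector, sorted ascending. It is an illustration only; flat_l2_search is a hypothetical name, not a Faiss or datasets function.

```python
import numpy as np

def flat_l2_search(xb, xq, k):
    """Exhaustive search like a flat L2 index: returns the k smallest
    squared Euclidean distances (ascending) and the matching row indices."""
    # squared L2 distance between the query and every database vector
    d2 = ((xb - xq[None, :]) ** 2).sum(axis=1)
    idx = np.argsort(d2, kind="stable")[:k]
    return d2[idx], idx

xb = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]], dtype=np.float32)
xq = np.array([1.0, 1.0], dtype=np.float32)
scores, indices = flat_l2_search(xb, xq, k=2)
```

Here the query [1, 1] is nearest to row 1 (squared distance 1), then row 0 (squared distance 2).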

The metric_type parameter of add_faiss_index lets you change this:

import faiss
...
ds.add_faiss_index(column, metric_type=faiss.METRIC_INNER_PRODUCT)
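With METRIC_INNER_PRODUCT, higher scores mean more similar, and raw inner products also reward vector length. The NumPy sketch below (flat_ip_search is a hypothetical name, used only for illustration) shows a long vector at 45 degrees outranking a short vector pointing exactly along the query, which is why inner product alone is not cosine similarity.

```python
import numpy as np

def flat_ip_search(xb, xq, k):
    """Exhaustive inner-product search: larger score = more similar,
    so results are sorted in descending order of score."""
    scores = xb @ xq
    idx = np.argsort(-scores, kind="stable")[:k]
    return scores[idx], idx

# Row 0 points in the same direction as the query; row 1 is longer but at 45 degrees
xb = np.array([[1.0, 0.0], [10.0, 10.0]], dtype=np.float32)
xq = np.array([1.0, 0.0], dtype=np.float32)
scores, indices = flat_ip_search(xb, xq, k=2)
```

Row 1 wins with a score of 10 versus 1, purely because of its magnitude.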

In particular, the Faiss documentation explains how to compute cosine similarity: L2-normalize the vectors, then search by inner product. With datasets, that means:

  1. Run .map on the dataset to normalize the embeddings column (with faiss.normalize_L2).
  2. Add the index with .add_faiss_index(column='embeddings', metric_type=faiss.METRIC_INNER_PRODUCT).
  3. Normalize the query vectors (with faiss.normalize_L2) before passing them to .get_nearest_examples.
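The three steps above can be checked with a small NumPy sketch. It assumes row-wise L2 normalization is what faiss.normalize_L2 does (l2_normalize here is a hypothetical NumPy stand-in) and shows that the inner product of unit vectors equals their cosine similarity.

```python
import numpy as np

def l2_normalize(x):
    """Row-wise L2 normalization; a NumPy stand-in for faiss.normalize_L2,
    which normalizes a float32 2-D array in place instead of returning a copy."""
    x = np.asarray(x, dtype=np.float32)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, 1e-12)  # guard against zero vectors

embeddings = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
query = np.array([[2.0, 0.0]], dtype=np.float32)  # shape (1, N)

# Steps 1 and 3: normalize the stored vectors and the query
db = l2_normalize(embeddings)
q = l2_normalize(query)

# Step 2: inner product of unit vectors is the cosine similarity (higher = closer)
scores = db @ q[0]
order = np.argsort(-scores, kind="stable")
```

With datasets and Faiss, step 1 can be done in a batched .map whose function builds emb = np.array(batch["embeddings"], dtype=np.float32), calls faiss.normalize_L2(emb) (which modifies emb in place) and returns {"embeddings": emb}. faiss.normalize_L2 expects a 2-D array, so reshape a single query to (1, N) before normalizing it; get_nearest_examples accepts that shape.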

You can find more details in the datasets documentation on search indexes.