Hello, I am trying to use Faiss indexing for an image similarity application. However, I am seeing an Attribute Error. Here’s the code that I am trying
# Loading dataset
dataset = load_dataset("imagefolder", data_dir="/data")
# Pre processing images
def transforms(examples):
examples["image"] = [image.resize((384, 384)).convert("RGB") for image in examples["image"]]
return examples
# Applying the transform function to dataset
dataset = dataset.map(transforms, batched=True, batch_size=48)
candidate_subset = dataset["train"].shuffle(seed=seed).select(range(num_samples))
test_dir = '/data/test'
train_dir = '/data/train'
model_ckpt = '/model'
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
hidden_dim = model.config.hidden_size
# Defining the embedding extractor using the model
def extract_embeddings(model: torch.nn.Module):
"""Utility to compute embeddings."""
device = model.device
def pp(batch):
images = batch["image"]
image_batch_transformed = torch.stack(
[transformation_chain(image) for image in images]
)
new_batch = {"pixel_values": image_batch_transformed.to(device)}
with torch.no_grad():
embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
return {"embeddings": embeddings}
return pp
# Applying embedding extractor train dataset
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = candidate_subset.map(extract_fn, batched=True, batch_size=batch_size)
Up until the above code we have extracted the embeddings for the candidate images.
Next, we get embedding from a test image, and apply faiss indexing
random_query_image = dataset["test"].shuffle(seed=seed).select(range(1))
random_query_image_emb = random_query_image.map(extract_fn, batched=True, batch_size=batch_size)
candidate_subset_emb.add_faiss_index(column='embeddings')
So far things are great.
However, this next line of code is giving an issue
scores, retrieved_examples = candidate_subset_emb.get_nearest_examples(
"embeddings", random_query_image_emb, k=5
)
Here’s the traceback:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[13], line 2
1 candidate_subset_emb.add_faiss_index(column='embeddings')
----> 2 scores, retrieved_examples = candidate_subset_emb.get_nearest_examples(
3 "embeddings", random_query_image_emb, k=5
4 )
File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:742, in IndexableMixin.get_nearest_examples(self, index_name, query, k, **kwargs)
727 """Find the nearest examples in the dataset to the query.
728
729 Args:
(...)
739 - examples (`dict`): The retrieved examples.
740 """
741 self._check_index_is_initialized(index_name)
--> 742 scores, indices = self.search(index_name, query, k, **kwargs)
743 top_indices = [i for i in indices if i >= 0]
744 return NearestExamplesResults(scores[: len(top_indices)], self[top_indices])
File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:702, in IndexableMixin.search(self, index_name, query, k, **kwargs)
687 """Find the nearest examples indices in the dataset to the query.
688
689 Args:
(...)
699 - indices (`List[List[int]]`): The indices of the retrieved examples.
700 """
701 self._check_index_is_initialized(index_name)
--> 702 return self._indexes[index_name].search(query, k, **kwargs)
File ~/miniconda3/envs/cortex-swin/lib/python3.10/site-packages/datasets/search.py:356, in FaissIndex.search(self, query, k, **kwargs)
353 if len(query.shape) != 1 and (len(query.shape) != 2 or query.shape[0] != 1):
354 raise ValueError("Shape of query is incorrect, it has to be either a 1D array or 2D (1, N)")
--> 356 queries = query.reshape(1, -1)
357 if not queries.flags.c_contiguous:
358 queries = np.asarray(queries, order="C")
AttributeError: 'Dataset' object has no attribute 'reshape'
I tried upgrading datasets package. That didn’t work. Any help would be appreciated. Thanks.