Hi,
I’m currently working on a project with Vision Transformers, and I wrote the following (this is only part of my code):
```python
from transformers import ViTFeatureExtractor, ViTForImageClassification
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datasets import DatasetDict

dataset = load_dataset("fashion_mnist")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

def transform(examples):
    # Fashion-MNIST images are grayscale, so convert them to 3-channel RGB for ViT
    return {
        "pixel_values": feature_extractor(examples['image'].convert("RGB"), return_tensors='pt')["pixel_values"],
        "labels": examples["label"]
    }

dataset = DatasetDict({
    "train": dataset["train"].map(transform),
    "test": dataset["test"].map(transform)
})
# Drop the original columns, keeping only "pixel_values" and "labels"
dataset = dataset.remove_columns(["label", "image"])
```
After running my preprocessing code, I expected the returned dataset to contain two features, “pixel_values” and “labels”, holding a PyTorch tensor and a list respectively. But when I inspected the dataset with:
```python
dataset["train"][:2]
```
The “pixel_values” feature turns out to be a nested list rather than a tensor. This causes problems later in my code when I try to load the data.
Could anyone help me please?