Error when using transform function of pixel_values

Hi,
I am working with the Flickr 8k Dataset from Kaggle and
have the following dataset:

image_dataset:
DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'labels'],
        num_rows: 7281
    })
    test: Dataset({
        features: ['pixel_values', 'labels'],
        num_rows: 810
    })
})

from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
from transformers import AutoImageProcessor, AutoTokenizer, ViTModel

vit_model = ViTModel.from_pretrained(config.ENCODER)
feature_extractor = AutoImageProcessor.from_pretrained(config.ENCODER)
tokenizer = AutoTokenizer.from_pretrained(config.DECODER)
tokenizer.pad_token = tokenizer.unk_token

normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std
)

_transforms = Compose([
    RandomResizedCrop(feature_extractor.size, scale=[0.8, 1]),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize
])

When I run the following command:
plt.imshow(_transforms(image_dataset['train'][0]['pixel_values']).detach().numpy().transpose(1, 2, 0))

I get this error:
 Traceback (most recent call last) ────────────────────────────────╮
│ /tmp/ipykernel_29/84940155.py:1 in <module>                                                      │
│                                                                                                  │
│ [Errno 2] No such file or directory: '/tmp/ipykernel_29/84940155.py'                             │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torchvision/transforms/transforms.py:95 in __call__      │
│                                                                                                  │
│     92 │                                                                                         │
│     93 │   def __call__(self, img):                                                              │
│     94 │   │   for t in self.transforms:                                                         │
│ ❱   95 │   │   │   img = t(img)                                                                  │
│     96 │   │   return img                                                                        │
│     97 │                                                                                         │
│     98 │   def __repr__(self) -> str:                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torchvision/transforms/transforms.py:980 in forward      │
│                                                                                                  │
│    977 │   │   │   PIL Image or Tensor: Randomly cropped and resized image.                      │
│    978 │   │   """                                                                               │
│    979 │   │   i, j, h, w = self.get_params(img, self.scale, self.ratio)                         │
│ ❱  980 │   │   return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=  │
│    981 │                                                                                         │
│    982 │   def __repr__(self) -> str:                                                            │
│    983 │   │   interpolate_str = self.interpolation.value                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torchvision/transforms/functional.py:663 in resized_crop │
│                                                                                                  │
│    660 │   if not torch.jit.is_scripting() and not torch.jit.is_tracing():                       │
│    661 │   │   _log_api_usage_once(resized_crop)                                                 │
│    662 │   img = crop(img, top, left, height, width)                                             │
│ ❱  663 │   img = resize(img, size, interpolation, antialias=antialias)                           │
│    664 │   return img                                                                            │
│    665                                                                                           │
│    666                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torchvision/transforms/functional.py:479 in resize       │
│                                                                                                  │
│    476 │   _, image_height, image_width = get_dimensions(img)                                    │
│    477 │   if isinstance(size, int):                                                             │
│    478 │   │   size = [size]                                                                     │
│ ❱  479 │   output_size = _compute_resized_output_size((image_height, image_width), size, max_si  │
│    480 │                                                                                         │
│    481 │   if (image_height, image_width) == output_size:                                        │
│    482 │   │   return img                                                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/torchvision/transforms/functional.py:387 in              │
│ _compute_resized_output_size                                                                     │
│                                                                                                  │
│    384 │   │                                                                                     │
│    385 │   │   new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)         │
│    386 │   else:  # specified both h and w                                                       │
│ ❱  387 │   │   new_w, new_h = size[1], size[0]                                                   │
│    388 │   return [new_h, new_w]                                                                 │
│    389                                                                                           │
│    390                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 1
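Since the failure happens at `new_w, new_h = size[1], size[0]`, I suspect the root cause is that `feature_extractor.size` is a dict (e.g. `{'height': 224, 'width': 224}` in recent `transformers` versions, keys depending on the processor) rather than an int or tuple, so torchvision ends up indexing it with integer keys. A minimal sketch of what I think is happening (the `size` value below is an assumption based on the traceback, not printed from my environment):

```python
# Recent AutoImageProcessor classes expose `size` as a dict; assumed value:
size = {"height": 224, "width": 224}

# torchvision's _compute_resized_output_size runs `size[1], size[0]`
# when `size` is not an int, which fails on a dict with string keys:
try:
    new_w, new_h = size[1], size[0]
except KeyError as e:
    print(f"KeyError: {e}")  # KeyError: 1, matching the traceback above

# Pulling the explicit values out of the dict would sidestep this:
crop_size = (size["height"], size["width"])
print(crop_size)  # (224, 224)
```

If this diagnosis is right, passing `crop_size` (or a single int) to `RandomResizedCrop` instead of `feature_extractor.size` should avoid the dict indexing, but I have not confirmed it yet.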

How can I resolve this issue?
I even tried the approach from the Image classification using LoRA guide, with no success.
Please let me know if you need any further information.
Thanks,
Ankush Singal