Question on fine-tuning the SAM model

Hi, I am trying to fine-tune the SAM model with masks that contain multiple bounding boxes, but I am getting the following error:

pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_boxes torch.Size([2, 4])
ground_truth_mask torch.Size([2, 512, 512])
batch["ground_truth_mask"].shape = torch.Size([2, 512, 512])
  0%|                                                               | 0/102 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/bmi-image/Shyam/Nuclei segmentation/NuINS_run_SAM_4.py", line 214, in <module>
    outputs = model(pixel_values=batch["pixel_values"].to(device),
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3-2020/envs/pytorch-gpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3-2020/envs/pytorch-gpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/deb3cz/.local/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py", line 1338, in forward
    raise ValueError(
ValueError: ('The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.', ' got torch.Size([2, 4]).')
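
From the error message, my understanding is that the model wants an extra "number of boxes" dimension, i.e. input_boxes of shape (batch_size, nb_boxes, 4), while my batch only has (batch_size, 4). A tiny illustration with made-up coordinates of what I think the difference is:

import torch

# what my batch currently contains: one flat box per image -> shape (2, 4)
flat_boxes = torch.tensor([[100.0, 100.0, 200.0, 200.0],
                           [ 50.0,  60.0, 150.0, 180.0]])

# what modeling_sam.py seems to ask for: (batch_size, nb_boxes, 4)
nested_boxes = flat_boxes.unsqueeze(1)

print(flat_boxes.shape, nested_boxes.shape)  # torch.Size([2, 4]) torch.Size([2, 1, 4])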

Here is my code:

# imports used by the code below (as far as this snippet needs them)
import numpy as np
from PIL import Image
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import SamProcessor

list_of_images_final = []
list_of_masks_final = []
list_of_annotations_final = []

# flatten everything to one (image, mask, box) entry per individual bounding box
for idx, annotations in enumerate(list_of_annotations):
    print(f'fix up the final lists, and now doing for index = {idx}')
    for annotation in annotations:
        print(f'annotation within second loop of annotation = {annotation}')
        list_of_images_final.append(list_of_images[idx])
        list_of_masks_final.append(list_of_masks[idx])
        annotation = [float(i) for i in annotation]
        list_of_annotations_final.append(annotation)
list_of_images_final = np.array(list_of_images_final)
list_of_masks_final = np.array(list_of_masks_final)

dataset_dict = {
    "image": [Image.fromarray(img) for img in list_of_images_final],
    "label": [Image.fromarray(mask) for mask in list_of_masks_final],
    "annotations": list_of_annotations_final,
}

dataset = Dataset.from_dict(dataset_dict)

class SAMDataset(Dataset):
  """
  This class is used to create a dataset that serves input images and masks.
  It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"], dtype=np.uint8)
    prompt = item['annotations']
    
    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
train_dataset = SAMDataset(dataset=dataset, processor=processor)
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=False, num_workers=4)
batch = next(iter(train_dataloader))
print(f'batch = {batch}')
for k,v in batch.items():
  print(k,v.shape)

print(f'batch["ground_truth_mask"].shape = {batch["ground_truth_mask"].shape}')

Can anyone please help me figure out what I am doing wrong here?