The most elegant solution I could come up with is to reimplement the _pad_points_and_labels method of SamProcessor as a standalone function: pass just the batch of images to the SamImageProcessor, and do the point padding outside of it:
import numpy as np
import torch


def _pad_points_and_labels(input_points, input_labels):
    r"""
    Pads the 2D points and labels to the maximum number of points in the batch.
    Adapted from processing_sam.py.
    """
    point_pad_value = -10
    input_points = [np.array(point) for point in input_points]
    expected_nb_points = max([point.shape[0] for point in input_points])
    processed_input_points = []
    for i, point in enumerate(input_points):
        if point.shape[0] != expected_nb_points:
            # pad the labels with the same pad value as the points
            input_labels[i] = np.append(
                input_labels[i], [point_pad_value] * (expected_nb_points - point.shape[0])
            )
            # pad the missing points with (pad_value, pad_value) coordinates
            point = np.concatenate(
                [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
            )
        input_labels[i] = torch.tensor(input_labels[i])
        processed_input_points.append(torch.tensor(point))
    input_points = processed_input_points
    input_labels = torch.stack(input_labels, dim=0)
    input_points = torch.stack(input_points, dim=0).unsqueeze(1)
    return input_points, input_labels
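As a quick sanity check (a minimal sketch with made-up coordinates and labels), two prompts with different point counts come back padded to the same shape:

points = [[[100, 200], [150, 250], [300, 400]],  # 3 points
          [[50, 60]]]                            # 1 point -> padded to 3
labels = [np.array([1, 1, 0]), np.array([1])]
padded_points, padded_labels = _pad_points_and_labels(points, labels)
print(padded_points.shape)  # torch.Size([2, 1, 3, 2])
print(padded_labels.shape)  # torch.Size([2, 3]); second row is [1, -10, -10]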
I use this function in the collate function as follows:
from torch.utils.data import default_collate


def _custom_collate_fn(batch):
    images = [sample.pop('cv_image') for sample in batch]
    points = [sample.pop('point_prompts') for sample in batch]
    input_labels = [sample.pop('point_labels') for sample in batch]
    # pass the images separately to the SamImageProcessor
    # (`image_processor` is defined in the enclosing scope)
    inputs = image_processor(images, return_tensors="pt")
    # pad points and labels outside the image_processor
    inputs['input_points'], inputs['input_labels'] = _pad_points_and_labels(points, input_labels)
    # collate the remaining keys as usual, then merge in the processed inputs
    out_dict = default_collate(batch)
    out_dict.update(inputs)
    out_dict['cv_image'] = images
    return out_dict
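For completeness, here is a sketch of how this could be wired into a DataLoader. The checkpoint name and the `train_dataset` variable are assumptions; the dataset just needs to yield dicts with 'cv_image', 'point_prompts', and 'point_labels' keys (any remaining keys are handled by default_collate):

from torch.utils.data import DataLoader
from transformers import SamImageProcessor

# assumed checkpoint; any SAM checkpoint should work
image_processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base")

# `train_dataset` is a placeholder for your own dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=_custom_collate_fn,
)

batch = next(iter(train_loader))
# batch now contains 'pixel_values', 'input_points', 'input_labels', ...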
Hope this helps.