When you pad images with zeros, the model still “sees” those zero regions, and because attention mixes information across patches, the results change even for the real parts of the image, a bit like empty seats in a meeting that still affect how the discussion goes.
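You can see this with a quick sketch (a random tensor stands in for a real image here, and the same checkpoint as below is assumed; exact numbers will vary): embed an image on its own, then embed it again inside a zero-padded canvas, and the pooled output shifts even though the real pixels are identical.

    import torch
    from transformers import Swinv2Model

    swin = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window16-256").eval()

    img = torch.randn(1, 3, 256, 256)                      # stand-in for a real image
    padded = torch.nn.functional.pad(img, (0, 64, 0, 64))  # zero-pad right and bottom

    with torch.no_grad():
        plain = swin(img).pooler_output
        mixed = swin(padded).pooler_output

    # The pooled embedding of the same real content shifts once padding is added
    print(torch.allclose(plain, mixed, atol=1e-4))          # typically False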
You need to tell the model to ignore the padded areas. Here’s a simple approach:
Simple solution: just add this wrapper around your existing Swin model.
import torch
import torch.nn as nn
from transformers import Swinv2Model

class FixedSwinModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Your existing model
        self.swin = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window16-256")
        # The final hidden state is downsampled 32x relative to the input
        # (patch size 4, then three 2x patch-merging stages)
        self.downsample = 32

    def forward(self, images, original_sizes):
        # Step 1: Run the model normally
        outputs = self.swin(images)

        # Step 2: Zero out embeddings that came from padded regions
        embeddings = outputs.last_hidden_state
        masked_embeddings = self.remove_padding_effect(embeddings, images, original_sizes)

        # Step 3: Return clean results
        outputs.last_hidden_state = masked_embeddings
        return outputs

    def remove_padding_effect(self, embeddings, images, original_sizes):
        """Zero out embeddings that come from zero-padded patches."""
        batch_size, _, padded_h, padded_w = images.shape
        features = embeddings.shape[-1]

        # Size of the final feature grid; ceil division matches the padding
        # SwinV2 applies internally at each stage
        total_h = (padded_h + self.downsample - 1) // self.downsample
        total_w = (padded_w + self.downsample - 1) // self.downsample

        # [batch, patches, features] -> [batch, grid_h, grid_w, features]
        # (clone so we don't modify the model's output in place)
        grid = embeddings.reshape(batch_size, total_h, total_w, features).clone()

        for i in range(batch_size):
            orig_h, orig_w = original_sizes[i]
            # Number of grid cells covered by real (unpadded) pixels
            real_h = (orig_h + self.downsample - 1) // self.downsample
            real_w = (orig_w + self.downsample - 1) // self.downsample
            # Padding sits at the bottom and right of each image, so the padded
            # cells are the trailing rows and columns of the grid
            grid[i, real_h:, :, :] = 0
            grid[i, :, real_w:, :] = 0

        return grid.reshape(batch_size, -1, features)
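If you want to sanity-check the grid math before wiring it into your pipeline, a quick run with dummy inputs (hypothetical sizes, zeros standing in for images) should produce the expected patch count:

    wrapper = FixedSwinModel().eval()
    dummy = torch.zeros(2, 3, 256, 320)
    with torch.no_grad():
        out = wrapper(dummy, [(200, 300), (256, 320)])
    # A 256x320 padded batch gives an 8x10 final grid, i.e. 80 patch embeddings
    print(out.last_hidden_state.shape)  # expected: torch.Size([2, 80, 768])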
How to use it (replace your current model):
model = FixedSwinModel()
Your existing batch processing with one small change:
def process_batch(image_list):
    # Store the original (height, width) of each image (this is the new line)
    original_sizes = [(img.shape[1], img.shape[2]) for img in image_list]

    # Your existing padding code
    max_h = max(img.shape[1] for img in image_list)
    max_w = max(img.shape[2] for img in image_list)

    padded_images = []
    for img in image_list:
        padded = torch.nn.functional.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]))
        padded_images.append(padded)
    batch = torch.stack(padded_images)

    # Updated model call (pass the original sizes)
    outputs = model(batch, original_sizes)
    return outputs
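For example (random tensors standing in for two differently sized images):

    images = [torch.randn(3, 224, 320), torch.randn(3, 256, 256)]
    outputs = process_batch(images)

    # Both images were padded to 256x320; embeddings from padded patches are zero
    print(outputs.last_hidden_state.shape)  # torch.Size([2, 80, 768])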
Don’t worry, the fix is actually quite straightforward.
You keep using the same Swin model and padding approach; the wrapper just zeroes out the embeddings that come from padded patches.
All you need to do is track each image's original size and pass it along to the model, which keeps your embeddings far more consistent between padded and unpadded batches.
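One extra note if you mean-pool last_hidden_state into a single vector per image: since the padded patches are zeroed, summing over all patches is safe, but divide by the number of real patches rather than the total, otherwise heavily padded images come out with artificially small embeddings. A rough sketch (pool_valid is just an illustrative helper, not part of any library, and assumes outputs came from the wrapper above):

    def pool_valid(outputs, original_sizes, downsample=32):
        # Mean-pool over real patches only; padded patches are already zero,
        # so the sum is unaffected, but the divisor must be the real count
        pooled = []
        for i, (h, w) in enumerate(original_sizes):
            n_valid = ((h + downsample - 1) // downsample) * ((w + downsample - 1) // downsample)
            pooled.append(outputs.last_hidden_state[i].sum(dim=0) / n_valid)
        return torch.stack(pooled)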
Good luck!