I’ve built an autoencoder that uses a vit_base_patch16_224 as the encoder. I’m training on around 100,000 screenshots, but after the first training step my embeddings are NaN. Maybe some of you have helpful advice about what I may have done wrong.
I’ve tried weight decay, gradient clipping, and a low learning rate; nothing works.
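In case it helps, this is roughly the training step I’ve been running (a simplified sketch from memory; the optimizer settings and clipping value here are illustrative, the actual loop is in the repo linked below):

```python
# Rough sketch of the training step (simplified; hyperparameter values here
# are illustrative, not necessarily what the repo uses).
model = ViTAutoEnc2().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
criterion = nn.MSELoss()

for images in dataloader:  # dataloader yields (batch_size, 3, 224, 224) tensors
    images = images.cuda()
    optimizer.zero_grad()
    recon, embeddings = model(images)
    loss = criterion(recon, images)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
```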
Here is the git project: https://github.com/tomad01/image_autoencoder/tree/dev
Here is the autoencoder class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model


class ViTAutoEnc2(nn.Module):
    def __init__(self, vit_model='vit_base_patch16_224'):
        super(ViTAutoEnc2, self).__init__()
        # Use a pre-trained Vision Transformer as the encoder
        self.encoder = create_model(vit_model, pretrained=True)
        # self.encoder.head = nn.Identity()
        # Extract the feature dimension from the ViT model
        self.vit_embedding_dim = self.encoder.embed_dim
        self.decoder = nn.Sequential(
            nn.Linear(768, 768 * 14 * 14),
            nn.Unflatten(1, torch.Size([768, 14, 14])),
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 28x28 -> 56x56
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 56x56 -> 112x112
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 112x112 -> 224x224
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=1),  # Adjust channels
            nn.Tanh()
        )
        count = sum(p.numel() for p in self.decoder.parameters() if p.requires_grad)
        print(f'The decoder has {count / 10**6:.2f} million trainable parameters')
        count = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
        print(f'The encoder has {count / 10**6:.2f} million trainable parameters')
        # Define normalization parameters (mean and std for ImageNet)
        # self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        # self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def forward(self, x, decode=True):
        # Encoder pass with ViT: mean-pool the tokens -> (batch_size, vit_embedding_dim)
        embeddings = self.encoder.forward_features(x).mean(dim=1)
        # embeddings = self.encoder(x)
        # print(embeddings.shape)
        if torch.isnan(embeddings).any():
            print("NaN found in encoder output")
        embeddings = F.normalize(embeddings, p=2, dim=1, eps=1e-8)
        if torch.isnan(embeddings).any():
            print("NaN found in normalized embeddings")
        # Decoder pass
        if decode:
            x = self.decoder(embeddings)  # Decoding
            # x_max = x.max(dim=[1, 2, 3], keepdim=True)[0]  # Max value per image, shape: (batch_size, 1, 1, 1)
            # x = x / x_max  # Divide the tensor by the maximum value
            # x = (x - self.mean) / self.std
        return x, embeddings

    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False
        print("Encoder frozen.")

    def unfreeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = True
        print("Encoder unfrozen.")
```