I'm failing to train a vit_base_patch16_224 model to create high-quality embeddings for screenshots

I’ve created an autoencoder with a vit_base_patch16_224 as the encoder. I’m training on around ~100,000 screenshots, but after the first training step my embeddings are NaN. Maybe some of you have helpful advice about what I may have done wrong.
I’ve tried weight decay, gradient clipping, and a low learning rate; nothing works.
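
Roughly, my training loop looks like this (a simplified sketch using the ViTAutoEnc2 class below, with a tiny fake dataset in place of my real screenshot loader; the MSE loss and the anomaly-detection call are just for illustration, not necessarily what I run in the repo):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.autograd.set_detect_anomaly(True)  # report the op that first produces NaN/Inf during backward

model = ViTAutoEnc2()
criterion = nn.MSELoss()  # example reconstruction loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)  # low LR + weight decay

# Tiny fake dataset standing in for my real screenshot loader
loader = DataLoader(TensorDataset(torch.rand(8, 3, 224, 224)), batch_size=4)

for (images,) in loader:
    optimizer.zero_grad()
    recon, embeddings = model(images)  # ViTAutoEnc2 returns (reconstruction, embeddings)
    loss = criterion(recon, images)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
```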

Here is the git project: https://github.com/tomad01/image_autoencoder/tree/dev
Here is the autoencoder class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model


class ViTAutoEnc2(nn.Module):
    def __init__(self, vit_model='vit_base_patch16_224'):
        super(ViTAutoEnc2, self).__init__()
        
        # Use a pre-trained Vision Transformer as the encoder
        self.encoder = create_model(vit_model, pretrained=True)
        # self.encoder.head = nn.Identity()
        # Extract the feature dimension from the ViT model
        self.vit_embedding_dim = self.encoder.embed_dim

        self.decoder = nn.Sequential(
            # Project the 768-dim embedding up to a 768x14x14 feature map
            # (768 = self.vit_embedding_dim for vit_base; this Linear alone is ~115M parameters)
            nn.Linear(768, 768 * 14 * 14),
            nn.Unflatten(1, (768, 14, 14)),
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 28x28 -> 56x56
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 56x56 -> 112x112
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 112x112 -> 224x224
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(32, 3, kernel_size=1),  # Adjust channels
            nn.Tanh()
        )

        count = sum(p.numel() for p in self.decoder.parameters() if p.requires_grad)
        print(f'The decoder has {count/10**6:.2f} million trainable parameters')
        
        count = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)
        print(f'The encoder has {count/10**6:.2f} million trainable parameters')
        # Define normalization parameters (mean and std for ImageNet)
        # self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        # self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


    def forward(self, x, decode=True):
        # Encoder pass with ViT

        embeddings = self.encoder.forward_features(x).mean(dim=1)  # mean-pool the ViT tokens -> (batch_size, embed_dim)
        # embeddings = self.encoder(x)
        # print(embeddings.shape)
        if torch.isnan(embeddings).any():
            print("NaN found in encoder output")

        embeddings = F.normalize(embeddings, p=2, dim=1, eps=1e-8)

        if torch.isnan(embeddings).any():
            print("NaN found in normalized embeddings")
        # Decoder pass
        if decode:
            x = self.decoder(embeddings)  # Decoding
            # x_max = x.max(dim=[1, 2, 3], keepdim=True)[0]  # Max value per image, shape: (batch_size, 1, 1, 1)
            # x = x / x_max  # Divide the tensor by the maximum value
            # x = (x - self.mean) / self.std
        return x, embeddings  # if decode=False, x is just the unmodified input
            
        
    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False
        print("Encoder frozen.")

    def unfreeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = True
        print("Encoder unfrozen.")```