JAX/Flax VQ autoencoder for Stable Diffusion

Hey, it seems to me that for Stable Diffusion, only the autoencoder with the KL penalty has been translated to JAX/Flax. Is there any plan to add support for the one with VQ regularization? Thanks :slight_smile:
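For context on what the VQ variant adds over the KL one: instead of a KL term on a Gaussian latent, the encoder output is snapped to the nearest entry of a learned codebook. Below is a minimal sketch in plain JAX, assuming a VQ-VAE-style quantizer (nearest-neighbour lookup, codebook loss plus a `beta`-weighted commitment loss, straight-through gradients). The function name `vector_quantize` and the `beta=0.25` default are illustrative assumptions, not an existing diffusers or Stable Diffusion API.

```python
import jax
import jax.numpy as jnp


def vector_quantize(z, codebook, beta=0.25):
    """Illustrative VQ bottleneck (hypothetical helper, not a diffusers API).

    z:        latents of shape (..., D)
    codebook: code vectors of shape (K, D)
    Returns (quantized latents with straight-through gradients,
             codebook indices of shape z.shape[:-1], scalar VQ loss).
    """
    flat = z.reshape(-1, z.shape[-1])  # (N, D)
    # Squared Euclidean distance from every latent to every code vector: (N, K)
    dist = (
        jnp.sum(flat ** 2, axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + jnp.sum(codebook ** 2, axis=1)[None, :]
    )
    idx = jnp.argmin(dist, axis=1)  # nearest code for each latent
    z_q = codebook[idx].reshape(z.shape)

    # Codebook loss moves codes toward encoder outputs;
    # commitment loss (weighted by beta) keeps encoder outputs near their codes.
    codebook_loss = jnp.mean((z_q - jax.lax.stop_gradient(z)) ** 2)
    commit_loss = jnp.mean((jax.lax.stop_gradient(z_q) - z) ** 2)
    loss = codebook_loss + beta * commit_loss

    # Straight-through estimator: forward pass uses z_q,
    # backward pass copies the gradient unchanged to z.
    z_q = z + jax.lax.stop_gradient(z_q - z)
    return z_q, idx.reshape(z.shape[:-1]), loss
```

In a Flax port, the codebook would typically be a module parameter and this function would sit between the encoder's output convolution and the decoder, where the KL model samples from its diagonal Gaussian.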