Use the Funnel Transformer and T5 models from the Hugging Face Hub, with some subclassing to convert them into a VAE for text.
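Roughly, the wiring I have in mind looks like this. A minimal Flax sketch, where `encoder` and `decoder` are hypothetical stand-ins for the subclassed Funnel encoder and T5 decoder (Hugging Face doesn't ship a Flax Funnel, so treat both as placeholders); the only genuinely new pieces are the pooling and projection layers around the latent:

```python
import jax.numpy as jnp
import flax.linen as nn

class TextVAE(nn.Module):
    """Structural sketch: a pretrained encoder/decoder pair with a
    latent bottleneck in between. `encoder`/`decoder` are placeholders
    for the subclassed Funnel encoder and T5 decoder."""
    encoder: nn.Module
    decoder: nn.Module
    latent_dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, input_ids, decoder_input_ids):
        # Funnel already downsamples the sequence, so mean-pooling its
        # final hidden states into one vector is a natural bottleneck.
        hidden = self.encoder(input_ids)
        pooled = jnp.mean(hidden, axis=1)
        z = nn.Dense(self.latent_dim)(pooled)               # latent code
        memory = nn.Dense(self.hidden_dim)(z)[:, None, :]   # 1-token "memory"
        # The T5 decoder cross-attends to the projected latent.
        logits = self.decoder(decoder_input_ids, memory)
        return logits, z
```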
The current SOTA text VAE is OPTIMUS, which still suffers from some posterior collapse.
Its training data is all open source.
From my experiments, an MMD-VAE doesn't suffer from as much posterior collapse in smaller-scale models, so why not try and scale up?
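For context, an MMD-VAE (InfoVAE) replaces the per-example KL term with a kernel-based MMD between the batch of latents and samples from the prior, which is the usual explanation for why it collapses less: no individual posterior gets pushed onto the prior, only the aggregate does. A minimal sketch with an RBF kernel (the kernel choice and bandwidth here are my assumptions, not fixed by the idea):

```python
import jax
import jax.numpy as jnp

def rbf_kernel(x, y, scale=1.0):
    # Pairwise squared distances between the two sample sets.
    d2 = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-d2 / (2.0 * scale * x.shape[-1]))

def mmd_loss(rng, z, scale=1.0):
    """Biased MMD^2 estimate between encoded latents z (batch, dim)
    and samples from a N(0, I) prior:
    E[k(p, p)] + E[k(q, q)] - 2 E[k(p, q)]."""
    prior = jax.random.normal(rng, z.shape)
    return (rbf_kernel(prior, prior, scale).mean()
            + rbf_kernel(z, z, scale).mean()
            - 2.0 * rbf_kernel(prior, z, scale).mean())
```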
I've actually already got the code working to turn them into a PyTorch MMD-VAE, so why not just convert it to JAX/Flax?
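Since the loss is just array math, the port should be mostly mechanical. A hedged sketch of what a Flax/optax training step might look like, reusing the hypothetical `TextVAE` and `mmd_loss` from the sketches above (`lambda_mmd` is a made-up hyperparameter name):

```python
import jax
import optax

def make_train_step(model, optimizer, lambda_mmd=10.0):
    @jax.jit
    def train_step(params, opt_state, rng, batch):
        def loss_fn(p):
            logits, z = model.apply({"params": p},
                                    batch["input_ids"],
                                    batch["decoder_input_ids"])
            # Token-level reconstruction loss...
            rec = optax.softmax_cross_entropy_with_integer_labels(
                logits, batch["labels"]).mean()
            # ...plus the MMD penalty in place of the usual KL term.
            return rec + lambda_mmd * mmd_loss(rng, z)
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss
    return train_step
```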