Train the best ever transformer-VAE

Use the Funnel Transformer + T5 models from the Hugging Face Hub, with some subclassing to convert them into a VAE for text.

The current SOTA VAE is OPTIMUS, which still suffers from some posterior collapse.

Its training data is all open source.

From my experiments, an MMD-VAE doesn't suffer from as much posterior collapse in smaller-scale models, so why not try to scale it up?

I've actually already got the code working to turn them into a PyTorch MMD-VAE, so why not just convert it to JAX/Flax?
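Since the plan is to port the MMD-VAE to JAX/Flax, here is a minimal sketch of the MMD regulariser in JAX. It assumes a Gaussian kernel whose bandwidth equals the latent dimension and a standard normal prior; the function names are hypothetical, and this is not the existing PyTorch code mentioned above.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel(x, y):
    """Gaussian kernel matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / d).

    x: (n, d), y: (m, d) -> (n, m). Using the latent size d as the
    bandwidth is a common heuristic, assumed here rather than tuned.
    """
    d = x.shape[-1]
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / d)


def mmd_loss(z, key):
    """Biased (V-statistic) MMD^2 between latent codes z and N(0, I) samples."""
    prior = jax.random.normal(key, z.shape)
    k_zz = gaussian_kernel(z, z).mean()
    k_pp = gaussian_kernel(prior, prior).mean()
    k_zp = gaussian_kernel(z, prior).mean()
    return k_zz + k_pp - 2.0 * k_zp


# Usage: codes drawn from the prior score low, shifted codes score high.
key = jax.random.PRNGKey(0)
k_z, k_prior = jax.random.split(key)
z = jax.random.normal(k_z, (256, 8))
loss_matched = mmd_loss(z, k_prior)
loss_shifted = mmd_loss(z + 3.0, k_prior)
```

In training, this term would be added with some weight to the reconstruction loss. Unlike the per-example KL term of a standard VAE, it pushes the distribution of codes across the whole batch towards the prior, which is the property usually credited with reducing posterior collapse.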


Maybe this could make the basis for a new kind of search engine?

This is a good idea. I have experience with and interest in VAEs, and I would like to contribute to this project.


Thanks Mina!

Are you in the Slack? It would be good to make a group for this.

I am not in the Slack. I have filled in the project form but haven't been added yet.

I would love to contribute to this. I have worked on VAEs before (here is a paper I wrote to handle the posterior collapse issue). Fraser, I pinged you in Slack; I would love to chat when you are available.


@Fraser Very interesting idea! I have good experience with T5 and have also worked on VAEs. I would love to be part of this project.


I’ve started a Slack so we can make plans and form a proper team.

Would be good to organise a group call when you're free?

Hey Mina, I’ve started a Slack so we can make plans and form a proper team.

Would be good to organise a call when you're free?

Hi, I would love to contribute to this work! I am interested in VAEs and transformers and have already worked with both (never at the same time, though).


Hello @Fraser & Team,
I am interested in being part of such an amazing project and team, and I will try my best to contribute to this VAE project. It would be nice if we could discuss some more learning resources that would be useful for this project. I can work in any time zone that is comfortable for everyone in the team. I also read your awesome article, 'Interpolating the internet', and found it very interesting!


Hey, I would like to be added to this project. It looks really interesting, and I would love to share my expertise here too. I have worked on VAEs in my own projects and have a good amount of experience with them.


I've made a clearer project post here, be sure to give it a like if you're interested in helping out!

I think this is a really cool project, so let's define it officially! Sadly, we are limited to using TPUs. If it's too complicated to convert the code to JAX, maybe options with PyTorch/XLA can be explored as well.

But cool idea, let's define it 🙂


Hi Patrick, thanks for the feedback!

I've linked a revised plan below.

The idea here is just to take an existing Flax T5 model and put an autoencoder between the encoder and decoder, as seen here.

So far I've had calls with 3 other team members, and we're really excited to see what this produces!

Forgot to include memory requirements…

Preferably we will train on a dataset of input and output sequences of length 256, with batch size 24, using a T5-base model.

T5-base has 220M params, while OPTIMUS has 227M and was trained with the same settings as above using 8 V100 GPUs.
A TPUv3-8 is roughly equivalent to 4 V100 GPUs, so we should be able to train with a batch size of at least 12, or with shorter sequences.
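The batch-size estimate above can be checked with a quick back-of-the-envelope calculation. This sketch takes the post's assumption that a TPUv3-8 is worth about 4 V100s, and it additionally assumes memory use scales roughly linearly with tokens per step. That ignores attention's quadratic cost in sequence length, so the shorter-sequence figure is a conservative estimate.

```python
# Back-of-the-envelope batch-size check (assumes linear scaling with devices
# and roughly linear memory use in tokens per step).
optimus_gpus = 8
optimus_batch = 24
tpu_v100_equiv = 4   # the post's estimate: TPUv3-8 ~ 4 V100 GPUs
seq_len = 256

per_gpu_batch = optimus_batch / optimus_gpus        # sequences per V100
tpu_batch = per_gpu_batch * tpu_v100_equiv          # sequences per step on a TPUv3-8

tokens_per_step_optimus = optimus_batch * seq_len
tokens_per_step_tpu = int(tpu_batch) * seq_len

# Alternative: keep batch size 24 and shorten sequences to fit the same token budget.
same_tokens_seq_len = tokens_per_step_tpu // optimus_batch

print(per_gpu_batch, tpu_batch, same_tokens_seq_len)  # 3.0 12.0 128
```

In other words, a TPUv3-8 should fit either batch size 12 at length 256, or batch size 24 at roughly length 128.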




Here’s a demo of the model trained on lines of Python code!
