Train the best ever transformer-VAE

Use the Funnel Transformer and T5 models from the Hugging Face Hub, with some subclassing, to convert them into a VAE for text.
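For context, pulling the pretrained pieces off the Hub is just the standard transformers Flax API. A minimal sketch of the T5 side (the Funnel Transformer half is omitted here, and "t5-base" is simply the obvious default checkpoint, not necessarily what the project will use):

```python
from transformers import FlaxT5ForConditionalGeneration, T5TokenizerFast

# Load the pretrained Flax T5 that the VAE would be built around.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

# Run just the encoder; this is the output the VAE bottleneck would compress.
inputs = tokenizer("a line of text to compress", return_tensors="np")
encoder_out = model.encode(input_ids=inputs["input_ids"],
                           attention_mask=inputs["attention_mask"])
print(encoder_out.last_hidden_state.shape)  # (1, seq_len, 768)
```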

The current SOTA text VAE is OPTIMUS, which still suffers from some posterior collapse.

Its training data is all open source.

From my experiments, an MMD-VAE doesn’t suffer from as much posterior collapse in smaller-scale models, so why not try to scale it up?

I’ve actually already got code working that turns them into a PyTorch MMD-VAE, so why not just convert it to JAX/Flax?
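For anyone unfamiliar, the MMD objective just swaps the usual KL term for a maximum mean discrepancy penalty between the encoder's latents and draws from the prior. A minimal JAX sketch with a Gaussian (RBF) kernel; the batch size, latent size and bandwidth below are illustrative, not the project's actual settings:

```python
import jax
import jax.numpy as jnp

def rbf_kernel(x, y, sigma=1.0):
    # Pairwise Gaussian kernel between rows of x and y.
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * sigma ** 2))

def mmd_loss(q_samples, p_samples, sigma=1.0):
    # Biased estimate of MMD^2 between q(z) samples and prior samples.
    k_qq = rbf_kernel(q_samples, q_samples, sigma)
    k_pp = rbf_kernel(p_samples, p_samples, sigma)
    k_qp = rbf_kernel(q_samples, p_samples, sigma)
    return jnp.mean(k_qq) + jnp.mean(k_pp) - 2.0 * jnp.mean(k_qp)

# Usage: compare encoder latents against draws from N(0, I).
latents = jax.random.normal(jax.random.PRNGKey(0), (24, 32))  # stand-in for encoder output
prior = jax.random.normal(jax.random.PRNGKey(1), (24, 32))
print(mmd_loss(latents, prior))
```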


Maybe this could form the basis for a new kind of search engine?

Thanks Mina!

Are you in the Slack? It would be good to make a group for this.

Would love to contribute to this. I have worked on VAEs before (here is a paper I wrote on handling the posterior collapse issue). Pinged you on Slack, Fraser. Would love to chat when you are available.


@Fraser very interesting idea. I have good experience with T5 and have also worked on VAEs, and I would love to be part of this project.


I’ve started a Slack so we can make plans and form a proper team.

Would be good to organise a group call when you’re free?

https://join.slack.com/t/transformer-vae/shared_invite/zt-s5yv7h9h-~d3m7UJlfVPu5tlj8iUvvQ

Hey Mina, I’ve started a Slack so we can make plans and form a proper team.

Would be good to organise a call when you’re free?

https://join.slack.com/t/transformer-vae/shared_invite/zt-s5yv7h9h-~d3m7UJlfVPu5tlj8iUvvQ

Hi, I would love to contribute to this work! I am interested in VAEs & transformers and have already worked with both (never at the same time, though).


Hello @Fraser & Team,
I am interested in being part of such an amazing project & team. I will try my best to contribute to this VAE project. It would be nice if we could discuss some more learning resources that would be useful for this project. I can work in any time zone that is comfortable for everyone on the team. I also read your awesome article, ‘Interpolating the Internet’, and found it very interesting!


I’ve made a clearer project post here; be sure to give it a like if you’re interested in helping out!

Think this is a really cool project - let’s define it officially! Sadly we are limited to using TPUs - if it’s too complicated to turn the code into JAX, maybe options with PyTorch/XLA can be explored as well…

But cool idea, let’s define it :slight_smile:


Hi Patrick, thanks for the feedback!

I’ve linked a revised plan below.

The idea here is just to take an existing Flax T5 model and insert an autoencoder between the encoder & decoder, as seen here.
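To make that concrete, here’s a rough Flax sketch of the kind of bottleneck that would sit between the two halves. It’s only an illustration: the mean-pooling, latent size and fixed sequence length are placeholder choices rather than the project’s actual code (T5-base’s hidden size of 768 is the one real number here):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LatentBottleneck(nn.Module):
    """Compress T5 encoder states to a single latent, then expand back
    to a sequence the decoder can cross-attend to."""
    latent_dim: int = 32
    hidden_dim: int = 768   # T5-base hidden size
    seq_len: int = 256

    @nn.compact
    def __call__(self, encoder_hidden_states):
        # (batch, seq_len, hidden) -> (batch, latent_dim)
        pooled = jnp.mean(encoder_hidden_states, axis=1)
        latent = nn.Dense(self.latent_dim)(pooled)
        # (batch, latent_dim) -> (batch, seq_len, hidden)
        expanded = nn.Dense(self.seq_len * self.hidden_dim)(latent)
        return latent, expanded.reshape(-1, self.seq_len, self.hidden_dim)

# Shape check with a dummy encoder output.
module = LatentBottleneck()
dummy = jnp.zeros((2, 256, 768))
variables = module.init(jax.random.PRNGKey(0), dummy)
latent, decoder_input = module.apply(variables, dummy)
print(latent.shape, decoder_input.shape)  # (2, 32) (2, 256, 768)
```

The MMD penalty from earlier would then be applied to `latent`, alongside the usual cross-entropy reconstruction loss from the decoder.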

I’ve had calls with 3 other team members so far and we’re really excited to see what this produces!

Forgot to include memory requirements…

Preferably, we’ll train on a dataset of input & output sequences of length 256, with batch size 24 and a T5-base model.

T5-base has 220M params while OPTIMUS has 227M and was trained with the same settings above using 8 V100 GPUs.
A TPUv3-8 is roughly equivalent to 4 V100 GPUs, so we should be able to train with at least batch size 12, or with shorter sequences.
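As a very rough sanity check on the parameter side alone (activations dominate in practice), assuming float32 weights, gradients and Adam moments:

```python
# Back-of-envelope parameter memory for T5-base with Adam in float32.
params = 220e6                   # T5-base parameter count
bytes_per_param = 4 + 4 + 4 + 4  # weights + grads + Adam m + Adam v
print(f"~{params * bytes_per_param / 1e9:.1f} GB")  # ~3.5 GB, before activations
```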

OPTIMUS (current SOTA VAE) https://arxiv.org/pdf/2004.04092.pdf


DEMO

Here’s a demo of the model trained on lines of Python code!

https://huggingface.co/spaces/flax-community/t5-vae


Hi everyone! This looks so awesome! I am using deep neural networks for composing music. In the past I used MusicVAE (LSTMs), and more recently GPT-2. Your work here would be a nice experimental ground!

Is the Slack still active?