Convert a T5 model into a variational autoencoder for text.
I have already built a project that does this in PyTorch, but it has never been trained at scale.
This project converts the autoencoder to Flax so it can be trained efficiently on a TPU, with the goal of training the largest Transformer-VAE yet!
The model will be trained on English text.
Building on T5-base, this will match the Optimus model.
The only additional parameters come from a small autoencoder module inserted between the encoder and decoder.
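As a rough picture of where those extra parameters live, here is a minimal numpy sketch of such a bottleneck. All names, the pooling strategy, and the latent size are illustrative assumptions, not the project's actual module:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, latent_dim, seq_len = 768, 32, 10  # 768 = T5-base hidden size; latent_dim is assumed

# The only new parameters: two projections around the latent code.
W_enc = rng.normal(0, 0.02, (d_model, latent_dim))   # compress encoder output
W_dec = rng.normal(0, 0.02, (latent_dim, d_model))   # expand back for the decoder

def bottleneck(encoder_hidden):
    # encoder_hidden: (seq_len, d_model) hidden states from the T5 encoder
    pooled = encoder_hidden.mean(axis=0)     # pool the sequence to one vector
    latent = pooled @ W_enc                  # compressed latent code
    decoder_input = latent @ W_dec           # project back to model width
    # broadcast the single vector across the decoder's sequence positions
    return latent, np.tile(decoder_input, (seq_len, 1))

hidden = rng.normal(size=(seq_len, d_model))
z, dec_in = bottleneck(hidden)
print(z.shape, dec_in.shape)  # (32,) (10, 768)
```

Everything else in the network keeps T5-base's pretrained weights, so the VAE adds only `2 * d_model * latent_dim` parameters in this sketch.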
Use the Wikipedia sentences dataset from OPTIMUS.
It comes pre-tokenized, so we'll need to use its tokenizer with T5.
The original PyTorch training script was adapted from the old Hugging Face CLM training script, so the Flax CLM script should be a good base to build on.
The original model was made in PyTorch, so some features can't be ported over. For example, I added a Prism layer to the PyTorch code, which requires FFTs.
Desired project outcome
A colab notebook where people can explore the Transformer-VAE’s latent space.
Interpolate between sentences.
Transfer the style/content of one sentence to another.
Do gradient descent on a sentence's latent code to get a desired sentiment, classification score, etc.
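The interpolation demo above boils down to blending two latent codes and decoding each step. A minimal sketch, assuming the sentences have already been encoded to latent vectors (the real model's encode/decode calls are omitted):

```python
import numpy as np

def interpolate(z_a, z_b, steps=5):
    # linear interpolation between two latent codes; each intermediate
    # vector would be fed to the decoder to produce a sentence
    ts = np.linspace(0.0, 1.0, steps)
    return [(1 - t) * z_a + t * z_b for t in ts]

# toy latent codes standing in for two encoded sentences
z_a = np.zeros(32)
z_b = np.ones(32)
path = interpolate(z_a, z_b, steps=5)
print([round(float(z[0]), 2) for z in path])  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

The gradient-descent demo works the same way, except that instead of blending, the latent code is updated along the gradient of a classifier's score with respect to it.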
Here are some background links to understand the context behind this project:
and the improvements post.
MMD-VAE, the MMD loss that Transformer-VAE uses.
OPTIMUS, the current SOTA text-VAE; it will give a sense of the outputs we should expect.
How does this work?
Currently the model takes a transformer encoder and decoder and puts a VAE between them.
The VAE forms a compressed latent code, which allows interpolating over the training data.
For regularisation, an MMD loss is used instead of KL divergence, which reduces posterior collapse.
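To make the regulariser concrete, here is a sketch of an MMD term with an RBF kernel (a common choice for MMD-VAEs; the bandwidth here is an illustrative assumption). Rather than a per-example KL penalty, it matches the batch of posterior latent codes to samples drawn from the prior:

```python
import numpy as np

def rbf_kernel(x, y, bandwidth=1.0):
    # pairwise squared distances between rows of x and y
    sq = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2.0 * bandwidth ** 2))

def mmd(z_posterior, z_prior, bandwidth=1.0):
    # biased MMD^2 estimate: zero when the two sets of samples
    # come from the same distribution
    k_pp = rbf_kernel(z_prior, z_prior, bandwidth).mean()
    k_qq = rbf_kernel(z_posterior, z_posterior, bandwidth).mean()
    k_pq = rbf_kernel(z_posterior, z_prior, bandwidth).mean()
    return k_pp + k_qq - 2.0 * k_pq

rng = np.random.default_rng(0)
z_q = rng.normal(size=(64, 32))   # latent codes from the encoder
z_p = rng.normal(size=(64, 32))   # samples from the N(0, I) prior
loss = mmd(z_q, z_p)              # small: distributions already match
```

Because MMD only constrains the aggregate distribution of latents rather than each individual posterior, the latent code is under less pressure to be ignored by the decoder, which is why it helps against posterior collapse.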
I really like this project and all the info provided here! Thanks a lot. Let's try to make this happen! Will ask some people as well.