Convert a T5 model into a variational autoencoder for text.
I have already built a project that does this in PyTorch, but it has never been trained at scale.
This project will port the autoencoder to Flax so it can be trained efficiently on a TPU, producing the largest Transformer-VAE ever trained!
Language
The model will be trained in English.
Model
Built on T5-base, this will match the Optimus model.
The only additional parameters come from a small autoencoder module that sits between the encoder and decoder.
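The bottleneck idea can be sketched in plain NumPy (a toy illustration, not the project's actual Flax module — the mean-pooling strategy, `latent_dim`, and all function names here are assumptions): pool the encoder's hidden states, project to a Gaussian latent, sample with the reparameterisation trick, and project back to the decoder's hidden size.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_bottleneck(hidden_dim, latent_dim):
    """Randomly initialise the three projections of a toy VAE bottleneck."""
    return {
        "to_mu": rng.normal(0, 0.02, (hidden_dim, latent_dim)),
        "to_logvar": rng.normal(0, 0.02, (hidden_dim, latent_dim)),
        "to_decoder": rng.normal(0, 0.02, (latent_dim, hidden_dim)),
    }

def bottleneck(params, encoder_hidden):
    """Pool encoder states, sample a latent code, project back for the decoder."""
    pooled = encoder_hidden.mean(axis=1)         # (batch, hidden): mean-pool over tokens
    mu = pooled @ params["to_mu"]                # (batch, latent)
    logvar = pooled @ params["to_logvar"]
    eps = rng.normal(size=mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps          # reparameterisation trick
    decoder_in = z @ params["to_decoder"]        # (batch, hidden), fed to the decoder
    # KL divergence of N(mu, sigma^2) from N(0, 1), summed over latent dims
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)
    return decoder_in, z, kl

params = init_bottleneck(hidden_dim=768, latent_dim=32)
hidden = rng.normal(size=(2, 16, 768))           # (batch, seq_len, hidden)
decoder_in, z, kl = bottleneck(params, hidden)
```

In the real model the same shapes apply, but the projections live in a Flax module and the KL term is weighted into the T5 reconstruction loss.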
Datasets
Use the Wikipedia sentences dataset from Optimus.
It comes pre-tokenized, so we’ll need to use its tokenizer with T5.
Training scripts
The original PyTorch training script was adapted from the old Hugging Face CLM training script, so the Flax CLM script should be a good base to build on.
Challenges
The original model was written in PyTorch, so some features may not port over directly. For example, I added a prism layer to the PyTorch code which requires FFTs.
Desired project outcome
A Colab notebook where people can explore the Transformer-VAE’s latent space:
Interpolate between sentences.
Transfer the style/content of one sentence to another.
Run gradient descent on a sentence’s latent code to reach a desired sentiment, classification score, etc.
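The interpolation demo reduces to moving between two sentences' latent codes and decoding each intermediate point. A NumPy sketch of spherical interpolation (slerp), which is often preferred over linear mixing for Gaussian latents (the function name and latent size are illustrative, not from the project):

```python
import numpy as np

def slerp(z0, z1, t):
    """Spherically interpolate between two latent vectors, t in [0, 1]."""
    z0_n = z0 / np.linalg.norm(z0)
    z1_n = z1 / np.linalg.norm(z1)
    omega = np.arccos(np.clip(np.dot(z0_n, z1_n), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        return (1 - t) * z0 + t * z1   # (nearly) parallel vectors: fall back to lerp
    return (np.sin((1 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)

rng = np.random.default_rng(0)
z_a, z_b = rng.normal(size=32), rng.normal(size=32)   # latent codes of two sentences
path = [slerp(z_a, z_b, t) for t in np.linspace(0, 1, 7)]
```

In the notebook, each vector on `path` would be fed through the decoder to produce the intermediate sentences; the gradient-descent demo works on the same latent codes, optimising them against an external classifier's score.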
Reads
Here are some background links to understand the context behind this project: