Background
FNet is an efficient Transformer model that replaces self-attention with Fourier Transforms.
The paper compares it with BERT and shows that it “trains nearly seven times faster on GPUs and twice as fast on TPUs”; it is also more memory efficient and faster at inference.
The authors recently released their code and pre-trained weights (the implementation uses Flax!).
However, it’s not yet capable of generation.
Here is the last paragraph of the FNet paper:
> Throughout this work we have restricted our attention to encoders. FNet decoders can be designed using causally masked discrete Fourier Transform matrices. However, in settings where the FFT is preferable, a lower level implementation would be required to introduce causal masking. Lastly, designing the equivalent of encoder-decoder cross-attention remains an avenue for future work, as evidence suggests that it is crucial to decoder performance (You et al., 2020).
[FNet: Mixing Tokens with Fourier Transforms (arXiv:2105.03824)](https://arxiv.org/abs/2105.03824)
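To make the “causally masked discrete Fourier Transform matrices” idea concrete, here is a minimal sketch of mixing tokens with an N×N DFT matrix whose upper triangle is zeroed out, so position j only mixes with positions ≤ j. This is my own illustration with assumed names, not code from the paper or the released repo:

```python
# Illustrative only: causal token mixing with a masked DFT matrix
# (names are my own, not from the FNet codebase).
import jax.numpy as jnp

def causal_dft_mixing(x):
    """x: (seq_len, d_model) real-valued token representations."""
    n = x.shape[0]
    k = jnp.arange(n)
    # Standard N x N DFT matrix: W[j, k] = exp(-2*pi*i*j*k / N).
    dft = jnp.exp(-2j * jnp.pi * jnp.outer(k, k) / n)
    # Zero out the upper triangle so position j only sees positions <= j.
    causal_dft = jnp.tril(dft)
    # Keep only the real part, as FNet does for its Fourier sublayer.
    return (causal_dft @ x).real
```

As the quoted paragraph notes, this matrix form is O(N²) and gives up the FFT speed-up, which is why a lower-level implementation would be needed to keep causal masking fast.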
Proposal
Here is the main thing I want to try:
- Use a hybrid encoder-decoder where the encoder is FNet and the decoder is a normal Transformer decoder with attention. The motivation is to take advantage of the more efficient encoder architecture and augment it with generation capability (a rough sketch of the architecture follows below).
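As a rough sketch of what this hybrid could look like in Flax (all module names, sizes, and hyperparameters below are illustrative assumptions, not the released FNet code; positional embeddings, dropout, and weight tying are omitted for brevity):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class FourierMixing(nn.Module):
    """FNet-style token mixing: real part of FFTs over the hidden and sequence dims."""

    @nn.compact
    def __call__(self, x):
        return jnp.fft.fft(jnp.fft.fft(x, axis=-1), axis=-2).real


class FNetEncoderBlock(nn.Module):
    d_model: int
    d_ff: int

    @nn.compact
    def __call__(self, x):
        x = nn.LayerNorm()(x + FourierMixing()(x))                    # mixing sublayer
        h = nn.Dense(self.d_model)(nn.gelu(nn.Dense(self.d_ff)(x)))   # feed-forward sublayer
        return nn.LayerNorm()(x + h)


class AttentionDecoderBlock(nn.Module):
    d_model: int
    d_ff: int
    num_heads: int

    @nn.compact
    def __call__(self, y, enc_out, causal_mask):
        # Causal self-attention over the target tokens.
        y = nn.LayerNorm()(y + nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads)(y, y, mask=causal_mask))
        # Cross-attention into the FNet encoder output.
        y = nn.LayerNorm()(y + nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads)(y, enc_out))
        h = nn.Dense(self.d_model)(nn.gelu(nn.Dense(self.d_ff)(y)))
        return nn.LayerNorm()(y + h)


class HybridFNetSeq2Seq(nn.Module):
    vocab_size: int = 32000   # all sizes here are illustrative, not tuned
    d_model: int = 512
    d_ff: int = 2048
    num_heads: int = 8
    num_layers: int = 6

    @nn.compact
    def __call__(self, src_ids, tgt_ids):
        embed = nn.Embed(self.vocab_size, self.d_model)
        # Encoder: FNet blocks only, no attention.
        x = embed(src_ids)
        for _ in range(self.num_layers):
            x = FNetEncoderBlock(self.d_model, self.d_ff)(x)
        # Decoder: standard Transformer blocks with causal self- and cross-attention.
        y = embed(tgt_ids)
        causal_mask = nn.make_causal_mask(tgt_ids)
        for _ in range(self.num_layers):
            y = AttentionDecoderBlock(self.d_model, self.d_ff, self.num_heads)(y, x, causal_mask)
        return nn.Dense(self.vocab_size)(y)   # next-token logits


# Quick shape check with dummy token ids.
model = HybridFNetSeq2Seq()
src = jnp.ones((2, 16), dtype=jnp.int32)
tgt = jnp.ones((2, 8), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), src, tgt)
logits = model.apply(params, src, tgt)   # (2, 8, vocab_size)
```

The only point of the sketch is the wiring: Fourier mixing in the encoder, causal self-attention plus cross-attention into the encoder output in the decoder.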
Details
- Language: monolingual, English
- Model: encoder-decoder; the encoder's self-attention is replaced with Fourier Transforms.
- Data: C4 or the Pile
- Training scripts: available in the repo above.
- What are possible challenges? The released code is encoder-only, so we need to write the decoder part ourselves.
- Desired project outcome: We will train an efficient model for generation tasks.
Discussion
Some more details for discussion:
- The ideal training objective could be prefix-LM or BART-like, so the model can generate longer text. We can compare the prefix-LM loss/inference cost with GPT-2; this is similar to making GPT-2 faster by encoding the prefix with FNet (see the sketch after this list).
- We can first train the hybrid model with the encoder initialized from the pre-trained weights, to make sure our code is correct (we need to write some modeling code).
- Later we can train the model from scratch, following the T5 model sizes, and see how it compares (it should be more efficient; we can also scale it slightly differently to match compute cost/performance).
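For the prefix-LM objective, the training loss could be sketched as: feed the prefix to the FNet encoder, feed the (shifted) continuation to the decoder, and compute cross-entropy only on non-padding continuation tokens. The `apply_fn` signature, split point, and names below are assumptions; any seq2seq model returning next-token logits would fit:

```python
# Minimal prefix-LM style loss sketch. `apply_fn(params, src_ids, tgt_ids)` is assumed
# to return next-token logits, e.g. the HybridFNetSeq2Seq sketch above via `model.apply`.
import jax
import jax.numpy as jnp


def prefix_lm_loss(params, apply_fn, tokens, prefix_len, pad_id=0):
    """tokens: (batch, seq_len) int ids; the first `prefix_len` tokens form the prefix."""
    prefix = tokens[:, :prefix_len]                  # encoded bidirectionally by FNet
    continuation = tokens[:, prefix_len:]            # generated autoregressively
    # Teacher forcing: the decoder sees the continuation shifted right by one position.
    decoder_in = jnp.pad(continuation[:, :-1], ((0, 0), (1, 0)), constant_values=pad_id)
    logits = apply_fn(params, prefix, decoder_in)    # (batch, cont_len, vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, continuation[..., None], axis=-1)[..., 0]
    mask = (continuation != pad_id).astype(token_ll.dtype)
    # Average negative log-likelihood over non-padding continuation tokens.
    return -(token_ll * mask).sum() / jnp.maximum(mask.sum(), 1.0)
```

The GPT-2 comparison would then amount to measuring this loss and the wall-clock training/inference cost against a decoder-only baseline of comparable size.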
Other easier things we can do:
- Reproduce it.
- Train it on the Pile (the original model is trained on C4); a small data-streaming sketch follows below.
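On the data side, a small sketch of streaming the pre-training text with the 🤗 datasets library; the dataset identifier and config are assumptions and may need adjusting (e.g. `allenai/c4`), and swapping in the Pile would only change the `load_dataset` call (its exact Hub identifier is left out here on purpose):

```python
# Sketch: stream C4 (the corpus FNet was pre-trained on) without downloading it fully.
from itertools import islice
from datasets import load_dataset

c4 = load_dataset("c4", "en", split="train", streaming=True)
for example in islice(c4, 3):
    print(example["text"][:80])
```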
Note
- The model is not yet in huggingface/transformers, but I would like to help add it: [🌟 New model addition: FNet · Issue #12411 · huggingface/transformers](https://github.com/huggingface/transformers/issues/12411)