Train a hybrid FNet for generation (FNet: Mixing Tokens with Fourier Transforms)

Background

FNet is an efficient transformer model that replaces self-attention with Fourier Transforms.
The paper compares it with BERT and shows it “trains nearly seven times faster on GPUs and twice as fast on TPUs”; it is also more memory efficient and faster at inference.
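
For intuition, the core token-mixing operation is tiny. A minimal JAX sketch of the Fourier sublayer described in the paper (one DFT along the hidden dimension, one along the sequence dimension, keeping only the real part; the function name is just for illustration):

    import jax.numpy as jnp

    def fourier_mix(x):
        # x: (seq_len, hidden_dim). fft2 applies a 1D DFT along each of
        # the two axes (together a 2D DFT); keeping only the real part
        # gives the FNet mixing output.
        return jnp.fft.fft2(x).real

In FNet this parameter-free step replaces the self-attention sublayer inside an otherwise standard transformer encoder block.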

The authors recently released their code and pre-trained weights (it uses Flax!).

However, it’s not yet capable of generation.
Here is the last paragraph of the FNet paper:

Throughout this work we have restricted our attention to encoders. FNet decoders can be designed using causally masked discrete Fourier Transform matrices. However, in settings where the FFT is preferable, a lower level implementation would be required to introduce causal masking. Lastly, designing the equivalent of encoder-decoder cross-attention remains an avenue for future work, as evidence suggests that it is crucial to decoder performance (You et al., 2020).

[2105.03824] FNet: Mixing Tokens with Fourier Transforms
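
To make the “causally masked DFT matrix” idea concrete, here is a rough sketch of my reading of that paragraph (not anything from the released code; `causal_dft_mix` is a made-up name): materialize the seq_len × seq_len DFT matrix, zero out everything above the diagonal, and mix along the sequence dimension with a plain matmul. Note this gives up the FFT’s O(n log n) cost, which is exactly the trade-off the authors point out.

    import jax.numpy as jnp

    def causal_dft_mix(x):
        # x: (seq_len, hidden_dim). The FFT of the identity matrix gives
        # the DFT matrix; tril zeroes the entries above the diagonal, so
        # position i only mixes information from positions <= i.
        n = x.shape[0]
        dft = jnp.fft.fft(jnp.eye(n), axis=0)
        causal_dft = jnp.tril(dft)
        return jnp.real(causal_dft @ x)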

Proposal

Here is the main thing I want to try:

  • Use a hybrid encoder-decoder where the encoder is an FNet and the decoder is a normal transformer with attention. The motivation is to take advantage of the more efficient architecture while augmenting it with generation capability. (A rough modeling sketch follows below.)
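
A bare-bones Flax sketch of what this hybrid could look like (all module and parameter names here are made up for illustration; the real implementation would build on the released FNet code):

    import flax.linen as nn
    import jax.numpy as jnp

    class FNetEncoderBlock(nn.Module):
        """Encoder block: Fourier mixing instead of self-attention."""
        hidden_dim: int
        mlp_dim: int

        @nn.compact
        def __call__(self, x):
            # Parameter-free token mixing via a 2D DFT (real part only).
            x = nn.LayerNorm()(x + jnp.fft.fft2(x).real)
            # Standard feed-forward sublayer.
            h = nn.Dense(self.hidden_dim)(nn.gelu(nn.Dense(self.mlp_dim)(x)))
            return nn.LayerNorm()(x + h)

    class AttentionDecoderBlock(nn.Module):
        """Decoder block: ordinary causal self- and cross-attention."""
        hidden_dim: int
        mlp_dim: int
        num_heads: int

        @nn.compact
        def __call__(self, y, enc_out, causal_mask):
            # Causal self-attention over the target tokens.
            self_attn = nn.SelfAttention(num_heads=self.num_heads)(y, mask=causal_mask)
            y = nn.LayerNorm()(y + self_attn)
            # Cross-attention from the decoder into the FNet encoder output.
            cross = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(y, enc_out)
            y = nn.LayerNorm()(y + cross)
            # Standard feed-forward sublayer.
            h = nn.Dense(self.hidden_dim)(nn.gelu(nn.Dense(self.mlp_dim)(y)))
            return nn.LayerNorm()(y + h)

The causal mask can be built with nn.make_causal_mask. Everything on the decoder side is a standard transformer, so the open design question from the paper (an FNet equivalent of cross-attention) is sidestepped here by just using ordinary attention over the Fourier-mixed encoder output.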

Details

  • Language: monolingual, English.
  • Model: encoder-decoder, with encoder self-attention replaced by Fourier Transforms.
  • Data: C4 or The Pile (see the data-loading sketch after this list).
  • Training scripts: available in the repo above.
  • Possible challenges: the released code is encoder-only, so we need to add the decoder part.
  • Desired project outcome: an efficient model trained for generation tasks.
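
For the data bullet, both corpora are available through the datasets library; something like the following (streamed, since both are large) should work, though the exact dataset ids/configs are assumptions and may need checking:

    from datasets import load_dataset

    # C4 (the corpus used in the FNet paper), streamed to avoid a full download.
    c4 = load_dataset("c4", "en", split="train", streaming=True)

    # The Pile would be swapped in the same way (its dataset id may differ).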

Discussion

Some more details for discussion:

  • The ideal loss could be prefix-LM or BART-like, so the model can better generate longer text. We can compare the prefix-LM loss/inference cost with GPT-2 (similar to making GPT-2 faster by encoding the prefix with FNet); see the loss sketch after this list.
  • We can first train the hybrid model with the encoder initialized from the pre-trained weights, to make sure our code is correct (we need to write some modeling code).
  • Later we can train the model from scratch, following the T5 model sizes, and see how it compares (it should be more efficient; we can also scale it slightly differently to match the compute cost/performance).
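
As a concrete reference for the prefix-LM idea in the first bullet, here is a minimal sketch of the masked loss over a single packed sequence (prefix_lm_loss is a hypothetical helper, not from any repo): cross-entropy is computed only on tokens after the prefix, so the prefix can be consumed bidirectionally (by the FNet encoder) while training stays autoregressive on the continuation.

    import jax.numpy as jnp
    import optax

    def prefix_lm_loss(logits, targets, prefix_len):
        # logits: (seq_len, vocab); targets: (seq_len,) int token ids.
        # Per-token cross-entropy, masked so only positions at or after
        # prefix_len (the continuation) contribute to the loss.
        per_token = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
        mask = jnp.arange(targets.shape[0]) >= prefix_len
        return (per_token * mask).sum() / mask.sum()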

Other easier things we can do:

  • Reproduce it.
  • Train it on The Pile (the original model is trained on C4).


Hi @ttj
Very interesting idea. I am planning to implement FNet in PyTorch. I made a PR a few days back and have yet to update it here: [WIP] FNet by gchhablani · Pull Request #12335 · huggingface/transformers · GitHub. Not sure if Flax would be easier.

Hi @gchhablani
Nice! Do you plan to also add the Flax version to huggingface/transformers? I think this project is Flax/JAX only.

I like this idea! @gchhablani - do you want to join? Would love to see this project take place :slight_smile:


:sweat_smile: I’m not sure if I’ll be able to devote enough time, but I will try. I am also planning to work on another project.
However, @ttj is planning to contribute Flax FNet to the repository, and I am planning to work on PyTorch in parallel. Wdyt? Is it okay if we only work on Flax, or should we use the repo directly?

Sounds great - maybe more people will join later on as well! Finalizing it for now :slight_smile:
