Pretrained language model that enables non-autoregressive generation

Description: What is the project about?

Pre-train a language model for non-autoregressive generation at scale.

Recent pre-trained language models such as BERT and ELECTRA have boosted performance on various natural language tasks. These approaches focus on scaling up to massive training data and on efficient pre-training. However, what their objectives actually capture is not clear, and this prevents further steps beyond them. Therefore, we need another way to build pre-trained language models with better interpretability.

The ability to generate language could improve the interpretability of a language model, because we can inspect the generated results. Some language models that generate language autoregressively, e.g. XLNet, have been proposed, but they generate tokens one by one, which makes decoding slow.

In this project, we implement a language model that is (1) theoretically motivated, (2) trained in a highly parallelized way, and (3) capable of fast decoding. We follow up on a recent paper and move its technique to the pre-training stage. We believe this model could give the research community a better understanding of pre-trained language models while efficiently utilizing state-of-the-art accelerators such as TPUv3.

Model: What model do you propose to use?

Transformer layers with a conditional random field (CRF) on top. The detailed architecture is described in this paper.
We will use the model architecture provided by Hugging Face and implement the additional conditional random field layer as part of our project.
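For reference, here is a minimal JAX sketch of a linear-chain CRF training objective. It assumes per-position emission scores of shape [T, V] (for example, FlaxBERT output logits) and a dense [V, V] transition matrix; the function names are hypothetical and not taken from the paper or from Hugging Face. The log partition function is computed with the forward algorithm via `jax.lax.scan`. Note that a dense transition matrix is only practical for a toy vocabulary: for BERT's roughly 30k-token vocabulary it would need about 9×10^8 entries.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def crf_log_partition(emissions, transitions):
    """log Z of a linear-chain CRF (forward algorithm).

    emissions:   [T, V] per-position scores (e.g. BERT output logits).
    transitions: [V, V] score for moving from token i to token j.
    """
    def step(alpha, emit_t):
        # alpha[i] = log-sum-exp of scores of all prefixes ending in token i.
        alpha_next = logsumexp(alpha[:, None] + transitions, axis=0) + emit_t
        return alpha_next, None

    alpha, _ = jax.lax.scan(step, emissions[0], emissions[1:])
    return logsumexp(alpha)


def crf_sequence_score(emissions, transitions, tags):
    """Unnormalized score of one token sequence `tags` of length T."""
    emit = jnp.sum(emissions[jnp.arange(tags.shape[0]), tags])
    trans = jnp.sum(transitions[tags[:-1], tags[1:]])
    return emit + trans


def crf_nll(emissions, transitions, tags):
    """Negative log-likelihood -log p(tags | x) under the CRF."""
    return crf_log_partition(emissions, transitions) - crf_sequence_score(
        emissions, transitions, tags)
```

The scan is sequential over T, which is the price paid for modeling dependencies between adjacent output tokens.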

Dataset: What dataset do you want to use?

The same data BERT was trained on; in practice, the Wikipedia dataset.

Training scripts: What training scripts do you want to use?

Training scripts will be created as part of the project.

Expected result: What do you expect the outcome of the project to be?

A new pre-trained architecture that generates long text quickly with non-autoregressive decoding.

Good idea! I want to join :slight_smile:

Great idea! I also want to join! :smiley:

That’s really fast. I would love to join. :+1:t2:

BTW, we can already sample from a pre-trained BERT just with proper parametrization and sampling schemes (like this). It seems the results are comparable, except for speed, which is the main point.

Yeah, I hadn’t followed up on the paper you mentioned.

Also, I appreciate that you correctly understood our project’s goal, which is an efficient decoding scheme.

I will check the paper.

Thanks,

I really like this project idea!

Do you think you will be able to write jax transformer code with conditional random fields in 3,4 days for the project @sh0416 ? Maybe just use FlaxBERT and add the conditional random field layer on top? Think we should keep it as simple as possible :slight_smile:

Also we probably don’t need to pretrain the whole model from scratch, but we can just use a pretrained BERT model and adapt it for fast non-autoregressive generation no?

Think this project could give a really nice demo for fast generation!

Think this is not the easiest project, but definitely one of the interesting ones! Finalizing it :slight_smile:

Thanks for reading our proposal.

  • Do you think you will be able to write jax transformer code with conditional random fields in 3,4 days for the project?

I plan to use FlaxBERT and implement the conditional random field on my own. :slight_smile: We want to keep it as simple as possible, so if a code snippet for a conditional random field already exists, we’d like to use it for this project.

  • Also we probably don’t need to pretrain the whole model from scratch, but we can just use a pretrained BERT model and adapt it for fast non-autoregressive generation no?

Yes, only fine-tuning a pre-trained model would be one simple way to accomplish this project. However, I want to do it from scratch because I have no experience pre-training language models due to a lack of computing resources. So I want to improve my pre-training skills through this event. Also, I believe it could be decent research work if it succeeds.

I think setting two milestones for this proposal would be better:

  • (Required) Reproduce the reference paper through fine-tuning
  • (Optional) Move the mechanism to the pre-training phase to enhance the language model's performance

Thanks for the useful advice @patrickvonplaten, I appreciate it!

Great! Think it’s very much feasible to implement conditional random fields on top of BERT - cool idea!

Regarding pretraining:
Pre-training BERT in English takes quite some time since the English dataset is so massive. Maybe just fine-tuning it makes sense as a first step? Or further pre-training an already pre-trained English BERT on some specific data?

Very much looking forward to this project :slight_smile:

The paper only fine-tunes the pre-trained model, so that seems sufficient. The paper also mentions that the authors implemented their code on top of the Hugging Face PyTorch implementation, although they haven’t published it yet. I think the project is simple enough :smiley:. I also expect that JIT compilation in JAX could further improve generation speed compared to the PyTorch implementation.
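To make the JIT point concrete, here is a minimal sketch of CRF decoding (Viterbi) compiled with `jax.jit`. It assumes per-position emission scores of shape [T, V] from one parallel BERT forward pass and a dense [V, V] transition matrix; `crf_viterbi` is a hypothetical name, and no speed numbers are implied. Only this O(T·V²) dynamic program runs sequentially; the Transformer itself runs once.

```python
import jax
import jax.numpy as jnp


@jax.jit
def crf_viterbi(emissions, transitions):
    """Most likely token sequence under a linear-chain CRF.

    emissions:   [T, V] per-position scores.
    transitions: [V, V] score for moving from token i to token j.
    Returns an int array of length T.
    """
    def forward(score, emit_t):
        # cand[i, j] = best score ending in token i, then moving to token j.
        cand = score[:, None] + transitions
        best_prev = jnp.argmax(cand, axis=0)  # back-pointer for each j
        return jnp.max(cand, axis=0) + emit_t, best_prev

    score, backptrs = jax.lax.scan(forward, emissions[0], emissions[1:])
    last = jnp.argmax(score)

    def backward(tag, bp_t):
        prev = bp_t[tag]
        return prev, prev

    # reverse=True walks the back-pointers from the end, but stacks the
    # outputs in forward order, so prefix[t] is the token at position t.
    _, prefix = jax.lax.scan(backward, last, backptrs, reverse=True)
    return jnp.concatenate([prefix, last[None]])
```

The first call traces and compiles the function; later calls with the same shapes reuse the compiled kernel.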

wanna join

@patrickvonplaten Could you add @cosmoquester to our project? Thanks for your effort.

@harrydrippin @calofmijuck @junhsss @cosmoquester I’ve just created a Discord channel named “non-ar-generation” in jax-community-week, so feel free to join.

Hi, I would like to join your Discord server. Could you please add me?

added you @VigneshBaskaran :slight_smile:

Hi @patrickvonplaten, I just wanted to be added to the Discord server to follow their discussion. Please don’t add me to the team for now :slight_smile: I am still exploring all the projects and haven’t decided which one to join yet. Sorry for the miscommunication.

I have almost completed the training script for this project. In short, the CRF layer and its approximate version (this was tough because there is no reference implementation…) are now working and perform better than the version without the CRF layer.
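The post does not say which approximation was used. One published option, shown here purely as an assumption, is the low-rank plus beam approximation from Sun et al. (2019), "Fast Structured Decoding for Sequence Models": the V×V transition matrix is factored as E1·E2ᵀ, and each position keeps only its top-k emission candidates, so every forward step works on a k×k block. The sketch below is a hypothetical illustration of that idea, not the project's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def approx_crf_log_partition(emissions, e1, e2, k):
    """Approximate log Z with low-rank transitions and top-k beams.

    emissions: [T, V] per-position scores.
    e1, e2:    [V, d] factors; the full transition matrix would be e1 @ e2.T.
    k:         number of candidate tokens kept per position (beam size).
    """
    # Keep only the k highest-scoring tokens at every position.
    top_scores, top_ids = jax.lax.top_k(emissions, k)  # both [T, k]

    def step(alpha, inputs):
        prev_ids, cur_ids, cur_scores = inputs
        # [k, k] block of the implicit transition matrix; never builds V x V.
        trans = e1[prev_ids] @ e2[cur_ids].T
        alpha_next = logsumexp(alpha[:, None] + trans, axis=0) + cur_scores
        return alpha_next, None

    xs = (top_ids[:-1], top_ids[1:], top_scores[1:])
    alpha, _ = jax.lax.scan(step, top_scores[0], xs)
    return logsumexp(alpha)
```

With k = V the result equals the exact log partition; with smaller k it sums over a subset of sequences and is therefore a lower bound.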

Here is a list of what I did during the community week.

  • The training curves are logged in Weights & Biases.
  • “solar-river-74” is the non-autoregressive BERT without the CRF layer.
  • “clear-resonance-78” is the non-autoregressive BERT with the CRF layer and consecutive eos tokens.
  • “clear-resonance-78” is clearly better than “solar-river-74”.
  • Ratio-first decoding is applied in both models.
  • Since the CRF layer requires jax scan, training takes longer than the baseline.
  • The evaluation metric is lower than what the paper reported, but it could be improved with extensive hyper-parameter tuning and larger-scale training. (So far, I have trained for only about 2 epochs.)
  • I will create a demo page using streamlit by tomorrow!
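On the "consecutive eos token" setup: assuming the model is trained to fill every position after the real sentence with eos tokens, a natural post-processing step (an assumption here, not confirmed by the post) is to cut each parallel-decoded sequence at its first eos. A minimal sketch, where `truncate_at_eos` and `eos_id` are hypothetical names:

```python
import jax.numpy as jnp


def truncate_at_eos(token_ids, eos_id):
    """Cut a decoded 1-D token sequence at its first eos token."""
    is_eos = token_ids == eos_id
    first_eos = jnp.argmax(is_eos.astype(jnp.int32))  # 0 if no eos at all
    # If no eos was generated, keep the whole sequence.
    end = jnp.where(is_eos.any(), first_eos, token_ids.shape[0])
    return token_ids[: int(end)]
```

For example, with `eos_id=2`, the decoded ids `[5, 7, 2, 2, 2]` become `[5, 7]`.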

Thanks,
I mentioned @patrickvonplaten because he might want to know about our project’s progress…