Science Tuesday: MARGE

For this Science Tuesday, I read MARGE and wrote up a brief summary, as well as some interesting questions to discuss @joeddav @srush @VictorSanh @thomwolf @clem @julien-c @teven @patrickvonplaten @yjernite (only allowed 10 tags)

Pre-training via Paraphrasing (MARGE)

Paper: published June 26, 2020
Authors are from Facebook AI Research:
Mike Lewis, Marjan Ghazvininejad, Gargi Ghosh, Armen Aghajanyan, Sida Wang, Luke Zettlemoyer.

Summary

Huge models trained with a masked-LM pretraining objective (or similar) memorize lots of facts in their parameters and don’t use any external storage to look up facts they are missing. Human brains seem to have separate systems for memorizing facts and generating language, and we often google things. In this spirit, the goal of many transformer+retriever models is to decouple memorization of facts from language understanding. MARGE stands for a Multilingual Autoencoder that Retrieves and GEnerates.

The pretraining setup:

Reconstruct the original document by retrieving related documents (from Wikipedia) and trying to regenerate the original: maximize the likelihood of the original doc conditioned on the retrieved docs and their relevance scores. This implicitly forces the retriever to learn to produce good relevance scores.
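Roughly, the objective looks like the following toy sketch (my own simplification, not the paper’s code: I collapse each retrieved doc to a single vector and bias one attention step with the relevance scores, whereas in the real model the scores bias the decoder’s cross-attention over every token of every retrieved doc; all sizes and values are made up):

```python
# Toy sketch of MARGE's objective: score retrieved docs by cosine similarity
# with the target, then let those scores bias the attention used when
# reconstructing the target. Shapes and values are illustrative only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
target_emb = torch.randn(d)           # embedding of the doc being reconstructed
retrieved_embs = torch.randn(5, d)    # embeddings of 5 retrieved "evidence" docs

# Relevance score f(x, z): cosine similarity between target and each candidate.
scores = F.cosine_similarity(target_emb.unsqueeze(0), retrieved_embs, dim=-1)

# Simplified cross-attention: relevance scores are added to the attention
# logits, so more relevant docs contribute more to the reconstruction.
query = torch.randn(d)
attn = torch.softmax(retrieved_embs @ query / d**0.5 + scores, dim=-1)
context = attn @ retrieved_embs       # evidence the decoder conditions on

# In MARGE, training maximizes log p(target | retrieved docs, scores); the
# gradient flowing back through the scores is what teaches the retriever.
print(scores, attn)
```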

There are some tricks for keeping relevant articles in each batch without scoring all of Wikipedia for every example.

Every 10k training steps, they remake their batches by computing the cosine similarity of every pair of docs, and then greedily adding source and target docs to batches such that the pairwise sum of cosine similarities increases the most. This seems hacky, but it allows them to get away without approximate nearest-neighbor search or some other expensive way of finding related docs. This, plus the fact that a randomly initialized encoder gives docs with lexical overlap a higher-than-random cosine similarity, allows the model to train from a random initialization.
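A rough sketch of that greedy batching, as I read it (my own illustration, not the released code; `build_batches`, the sizes, and the per-batch document count are all made up):

```python
# Greedy batch construction: re-embed all docs periodically, then group each
# batch around documents with high mutual cosine similarity.
import numpy as np

def build_batches(doc_embs, docs_per_batch=64):
    # Normalize so dot products are cosine similarities.
    embs = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    sims = embs @ embs.T

    unassigned = set(range(len(embs)))
    batches = []
    while unassigned:
        # Seed a batch with an arbitrary remaining doc, then greedily add the
        # doc whose summed similarity to the current batch members is largest.
        batch = [unassigned.pop()]
        while len(batch) < docs_per_batch and unassigned:
            remaining = list(unassigned)
            gains = sims[np.ix_(remaining, batch)].sum(axis=1)
            best = remaining[int(np.argmax(gains))]
            batch.append(best)
            unassigned.remove(best)
        batches.append(batch)
    return batches

# Example: 2,000 toy documents with 32-dim embeddings.
batches = build_batches(np.random.randn(2000, 32))
print(len(batches), len(batches[0]))
```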

The retrieval model, ideally, can focus on getting the transformer all the facts that it needs while the transformer learns to paraphrase, which requires generating fluent language.

For finetuning/inference, you don’t need to use the retrieval part.

MARGE performs:

  • comparably to XLM-RoBERTa, with 20% of the pretraining compute.
  • comparably to mBART on de-en and en-zh translation.
  • SOTA on MLSum, a cross-lingual summarization task.

Key contributions:

(1) Most of the related work is not multilingual.

(2) Most of the related work does not zero-shot well?

(3) This pretraining objective unifies learning to retrieve and learning to generate; previous work requires two pretraining stages.

Related Work

REALM: “At a high level, the method goes like this: find the most similar text passages in BERT space, add those passages to the input as additional context, and then make a prediction.” -Joe, a few weeks ago

  • Different because the retriever has to be pretrained separately. REALM also seems to evaluate mostly on open-domain QA benchmarks.

RAG (Retrieval-Augmented Generation)

  • Different because it is mostly focused on knowledge-intensive benchmarks, whereas MARGE can also do well on translation.
  • Starts from BART-large + DPR, whereas MARGE pretrains end-to-end.

Questions somebody could answer:

  • Does MARGE outperform BART on English-only benchmarks like GLUE or XSum summarization? Why did they only show multilingual benchmarks?
  • When will there be code?
  • How long does a forward pass take?
  • What are the consequences of not using retrieval during inference? Does the model not “know” anything?

Higher Level:

  • Is Translation “knowledge intensive”?
  • How could we measure hallucinations?
  • The authors suggest that we should use a pre-training objective that is as close as possible to the downstream task. The Pegasus paper also suggests this. Where else could this idea be applied?

Also these two talks are good:
https://slideslive.com/38929793/beyond-bert (Mike Lewis at ACL)
https://www.youtube.com/watch?v=KTQPWoQ7Ol8 (Luke Zettlemoyer at AKCD)

12 Likes

From Mike Lewis, the 1st author:

  • We didn’t try very hard, but from what I saw MARGE lags a little behind BART on monolingual English tasks. It’s not too surprising, because I think having to be a good multilingual model just dilutes the capacity a bit. Similarly, XLM-R isn’t quite at RoBERTa level.

  • Code is coming soon.

  • They also retrieve from CC-News, not just Wikipedia.

  • “We’re going to look at retrieval during inference, but haven’t run that yet. Qualitatively, I think it’s a bit less prone to hallucination than BART because it (somewhat) knows that it doesn’t know anything. That means we get surprisingly literal zero-shot translations, because it tends not to make too much stuff up.”

2 Likes

Hadn’t read about this. Cool stuff!

Every 10k training steps, they remake their batches by computing the cosine similarity of every pair of docs, and then greedily adding source and target docs to batches such that the pairwise sum of cosine similarities increases the most.

You seem to imply that this is not an expensive operation, but it sounds very expensive: compute a vector for every doc, then cosine similarities between all pairs, greedily. Isn’t that super computationally expensive?

1 Like

In the paper, they separate the dataset into many shards, each consisting of similar documents, so that they only need to compute cosine similarities between documents within the same shard. More generally, instead of shards you can use faiss to cluster the embeddings and compute kNN efficiently.
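For reference, here is a minimal faiss sketch of that kind of clustered lookup (the standard IVF recipe with made-up sizes, not anything MARGE-specific):

```python
# Cluster document embeddings into cells, then search only a few cells per
# query instead of the whole corpus (the "shard" idea, automated by faiss).
import faiss
import numpy as np

d = 128
embs = np.random.randn(100_000, d).astype("float32")
faiss.normalize_L2(embs)                 # so inner product == cosine similarity

nlist = 1024                             # number of clusters ("shards")
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
index.train(embs)                        # learn the cluster centroids
index.add(embs)
index.nprobe = 16                        # clusters visited per query

# Top-8 most similar documents for the first 32 docs.
scores, neighbors = index.search(embs[:32], 8)
print(neighbors.shape)                   # (32, 8)
```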

Also, the embedder’s forward pass costs only a fraction of each training iteration in terms of compute, so computing the embeddings isn’t expensive either.

4 Likes

Thanks, I am aware of faiss. We use it in our work too, as an alternative (and addition) to edit distance. It is very fast, but considering the size of the dataset, this will still take quite some time. If you want to compare all inputs to all other inputs every x steps, that is still an expensive lookup. But if I understand your comment correctly, documents are only compared within the same shard, and the shards are created by some preprocessing that clusters similar documents together. So all inputs are not compared with all the others, only with those in their own shard.

1 Like

Right. But using faiss for all documents, without shards, is actually still fast enough.

Say the training dataset contains 128 billion tokens. If your batch size is 2M tokens and you refresh the kNN every 10k iterations, then you refresh it every 20B tokens of training. Since the embedder’s forward pass is about 6x faster per document than a training iteration (2x from using half as many layers and 3x from forward-only vs. forward+backward), embedding the whole corpus costs about as much as training for those 10k iterations (128B/6 ≈ 20B).
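Spelling the same back-of-envelope numbers out explicitly:

```python
# Back-of-envelope version of the estimate above, with the same assumed numbers.
corpus_tokens  = 128e9    # training set size in tokens
batch_tokens   = 2e6      # tokens processed per training step
refresh_steps  = 10_000   # steps between kNN refreshes
embed_speedup  = 2 * 3    # embedder: half the layers (2x), forward-only (3x)

train_tokens_per_refresh = batch_tokens * refresh_steps      # 20B tokens
embedding_cost_in_tokens = corpus_tokens / embed_speedup     # ~21B "token-equivalents"

print(train_tokens_per_refresh / 1e9, embedding_cost_in_tokens / 1e9)
# -> 20.0 and ~21.3: re-embedding the corpus costs roughly one refresh
#    interval's worth of training compute.
```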

The training dataset contains 128 billion tokens and each document is 128 tokens (512 in the paper, so even fewer documents), which gives you about 1 billion vectors. As in kNN-LM, you can use a subset of them to compute the centroids and then search (with GPUs) after quantization as usual. If you take a look at the original faiss paper, the compute required to construct a kNN graph over 1 billion vectors is not much: no more than about 10 GPU-hours on a single V100, much smaller than what it takes to train a SOTA LM on 20 billion tokens, so it’s still fast enough relative to the training.

Depending on your perspective, you may argue that this still costs too much or, for example, that the batch size is too large in this case. My point is that the frequency of updating the kNN is just a hyperparameter that can be adjusted to make the kNN part reasonably cheap. Since it’s not expensive in the case I suggested (which I believe is a reasonable one), MARGE isn’t inherently expensive. You can make the cost reasonable by investigating the trade-off and finding a reasonable compromise.

6 Likes

Interesting! Thanks for the elaborate explanation. I can only encourage and be happy about more efficient models.

1 Like

@BramVanroy @AranKomatsuzaki
I wonder if we can use the same strategy to fine-tune the RAG retriever end-to-end, since currently we only fine-tune the query encoder (the doc encoder and index stay fixed).