MPNet: Inconsistencies between data collator output and the masked permutation described in the original MPNet paper

Hi all,

Not sure if this is the right space to ask, but I’m in a bit of a pickle. Hopefully there are some experts out there who have worked with MPNet’s architecture a bit more than I have.

I am in the process of converting the original fairseq pretraining code for MPNet (repo here) into a training loop that is compatible with Hugging Face. Although many of the convenience classes already exist in Hugging Face (like MPNetForMaskedLM), one thing that has become clear is that we will need to port over the collator function written by the research team in MaskedDataset (under tasks/masked_permutation_lm).

In exploring how this collator works, I understand the logic to be:

  1. Permute the input IDs (by whole-word spans or individual tokens, depending on an argument) along with their positions
  2. Create masked/corrupted tokens from the final n indices of the permuted sequence, where n is the prediction size (i.e. seq_len x 0.15 at default values)
  3. Concatenate these together using concat(seq, mask, mask) and concat(positions, predict_positions, predict_positions), as sketched below
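
To make this concrete, here is a minimal sketch of my reading of those three steps (function and variable names are mine, not the repo’s, and the real collator’s 80/10/10 corruption is simplified to plain masking):

import torch

def sketch_collate(src_tokens, pred_ratio=0.15, mask_idx=0):
    # My reading of MaskedDataset's collation, heavily simplified
    seq_len = src_tokens.size(0)
    pred_size = int(seq_len * pred_ratio)

    # 1. Permute the tokens and their positions together
    perm = torch.randperm(seq_len)
    tokens = src_tokens[perm]
    positions = torch.arange(seq_len)[perm]

    # 2. The final pred_size entries of the permutation are the prediction
    #    targets; build mask tokens for them (no 80/10/10 corruption here)
    masks = torch.full((pred_size,), mask_idx, dtype=tokens.dtype)
    pred_positions = positions[-pred_size:]

    # 3. concat(seq, mask, mask) and concat(positions, pred_positions, pred_positions)
    new_ids = torch.cat([tokens, masks, masks])
    new_positions = torch.cat([positions, pred_positions, pred_positions])
    return new_ids, new_positions

# e.g. sketch_collate(torch.arange(10, 31)) returns 21 + 2*3 = 27 ids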

Using this logic, we might expect the collator function to perform the below operation on some dummy input IDs (21 tokens at the default 0.15 ratio gives a prediction size of 3):

src_tokens = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]

# Once the collator permutes everything and we append the mask portions, we expect something like
new_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22,  16,  24,  25, <mask>,  <corrupted>, <mask>, <mask>,  <corrupted>, <mask>]
new_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15,  6, 14, 15]

However, after rereading the MPNet paper, especially Sections 2.2 and 2.3 with attention to Figure 2, it would SEEM that the output of the collator is incongruous with what is described in those sections.

Figure 2 shows that the content and query masks are built from a permuted sequence that looks like this:

src_tokens = [x_1, x_2, x_3, x_4, x_5, x_6]

# Once permuted we get:
new_ids = [x_1, x_3, x_5, <mask>, <mask>, <mask>,  x_4, x_6, x_2]
new_positions = [1, 3, 5, 4, 6, 2, 4, 6, 2]
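
For contrast, here is the same kind of sketch for the Figure 2 layout as I read it (again with hypothetical names, and assuming tokens and positions have already been permuted): the masks sit where the predicted tokens were, and the predicted content is appended once at the end.

import torch

def sketch_paper_layout(tokens, positions, pred_size, mask_idx=0):
    # Figure 2 as I understand it: non-predicted content, then masks at the
    # predicted positions, then the predicted tokens' actual content
    non_pred, pred = tokens[:-pred_size], tokens[-pred_size:]
    non_pred_pos, pred_pos = positions[:-pred_size], positions[-pred_size:]

    masks = torch.full((pred_size,), mask_idx, dtype=tokens.dtype)
    new_ids = torch.cat([non_pred, masks, pred])  # len = seq_len + pred_size
    new_positions = torch.cat([non_pred_pos, pred_pos, pred_pos])
    return new_ids, new_positions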

In this example from the paper, the prediction-size tokens are masked in place and their content is appended to the end for the content stream. However, the collator output KEEPS the predicted tokens’ content in the main sequence and then adds TWO batches of mask tokens to the end (seq_len + 2 x pred_size ids in total, versus the paper’s seq_len + pred_size), which to me seems necessarily different from what’s described in the paper. Referring back to our dummy example above, I can outline the discrepancies I’m seeing:

collator_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22,  16,  24,  25, <mask>,  <corrupted>, <mask>, <mask>,  <corrupted>, <mask>]
collator_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15,  6, 14, 15]

paper_ids = [ 20,  23,  30,  14,  15,  27,  28,  11,  12,  17,  18,  26,  29,  13, 10,  19,  21,  22, <mask>,  <corrupted>, <mask>, 16, 24, 25]
paper_positions = [10, 13, 20,  4,  5, 17, 18,  1,  2,  7,  8, 16, 19,  3,  0,  9, 11, 12, 6, 14, 15,  6, 14, 15]

My question, then, is this: am I correct in understanding that the collator implementation differs from what’s described in the paper? If so, why?