The Correct Attention Mask for Example Packing

Sometimes we pack multiple short examples into the same input sequence to increase training efficiency (so we don't waste computation on adding and attending over padding tokens).
For example, given sample 1 = ['a', 'b', 'c', '<eos>'] and sample 2 = ['d', 'e', 'f', '<eos>'], we can pack them into packed_sample = ['a', 'b', 'c', '<eos>', 'd', 'e', 'f', '<eos>'].
This procedure is quite simple, and I think ConstantLengthDataset and the group_texts(examples) function in examples/pytorch/language-modeling/run_clm.py can handle it well. However, I think during training we can't just use the original causal (lower-triangular) mask for it, otherwise packed samples will attend to information from other samples, which is unwanted. To be more specific, for packed_sample = ['a', 'b', 'c', '<eos>', 'd', 'e', 'f', '<eos>'], I think the correct mask should be block-diagonal and causal, like this:
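Here is a rough PyTorch sketch of what I mean, assuming two packed samples of length 4 each (build_packed_causal_mask is just a name I made up for illustration, not an existing API):

```python
import torch

def build_packed_causal_mask(lengths):
    """Build a block-diagonal causal mask for samples packed into one sequence.

    `lengths` are the per-sample token counts; True means "may attend".
    """
    total = sum(lengths)
    # Assign each position the index of the sample it belongs to.
    sample_ids = torch.repeat_interleave(
        torch.arange(len(lengths)), torch.tensor(lengths)
    )
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)
    return causal & same_sample

mask = build_packed_causal_mask([4, 4])  # two packed samples of length 4
print(mask.int())
# Each token attends only to earlier tokens of its own sample:
# rows/cols 0-3 form one causal block, rows/cols 4-7 another.
```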


Am I correct? If yes, is there any existing implementation for this? Any idea would be really appreciated; I can implement it myself. :smiley:


Hi @OrionZheng, were you able to find the answer to this question?
I also have the same query.

Hi @OrionZheng, did you figure out how to achieve it? I am guessing Hugging Face should have support for this.

Is there any update on this?

In case someone is interested: yes, @OrionZheng’s idea is correct. See more details here: functionary/functionary/train/packing at main · MeetKai/functionary · GitHub


If anyone is interested, I also wrote a short blog post about how it could be done in PyTorch: https://huggingface.co/blog/sirluk/llm-sequence-packing
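To illustrate the general idea (this is just a rough sketch, not the actual code from the post): a boolean block-diagonal causal mask can be passed directly to torch.nn.functional.scaled_dot_product_attention in PyTorch 2.x, where True marks key positions a query may attend to.

```python
import torch
import torch.nn.functional as F

# Toy packed batch: two samples of lengths 4 and 4 in one sequence of 8 tokens.
lengths = [4, 4]
sample_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))
total = sample_ids.numel()

# Block-diagonal causal mask: True = "query may attend to key".
attn_mask = torch.tril(torch.ones(total, total, dtype=torch.bool)) & (
    sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)
)

# Random q/k/v just to show the call shape: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 2, total, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```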


In the case of packing, why can’t an attention mask like [1,1,1,1,2,2,2,2] be passed and interpreted for self-attention (consistent with the original post)?

The code for AttentionMaskConverter is a bit confusing, and from the 4D output it’s unclear to me whether this is already handled correctly or not. Does anyone know?

In this example (shortened to make it easier to write), [1, 1, 2, 2] translates to

[[[[0, m, m, m],
   [0, 0, m, m],
   [0, 0, m, m],
   [0, 0, m, m]]]]

where m is the minimum float value (about -3.4028e38).

That seems to say that sample 2 attends to sample 1, which is not what we want here.

Thoughts? Am I messing up the interpretation?
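For what it’s worth, here is a rough sketch of what I would expect instead (my own experiment, using the 0 / minimum-float additive convention shown above; expand_packed_mask is just an illustrative name, not an existing transformers function):

```python
import torch

def expand_packed_mask(sample_ids, dtype=torch.float32):
    """Expand a 1D vector of sample IDs into an additive 4D attention mask.

    Positions get 0 where a query may attend (same sample, not in the future)
    and the minimum float value everywhere else.
    """
    seq_len = sample_ids.shape[-1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_sample = sample_ids.unsqueeze(-1) == sample_ids.unsqueeze(-2)
    allowed = causal & same_sample
    mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return mask.masked_fill(allowed, 0.0)

print(expand_packed_mask(torch.tensor([1, 1, 2, 2])))
# Expected: two 2x2 causal blocks on the diagonal, minimum value everywhere else,
# i.e. sample 2 no longer attends to sample 1.
```

If I understand correctly, newer transformers versions let you pass a custom 4D attention_mask like this directly to the model, though I haven’t verified which models support it.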

Thanks!
