What does increasing the number of heads do in multi-head attention?

Can someone explain to me the point of the number of heads in MultiheadAttention?

What happens if I increase or decrease them?

Would it change the number of learnable parameters?

What is the intuition behind increasing or decreasing the number of heads in MultiheadAttention?

Changing the number of heads can change the number of learnable parameters, but it depends on how the heads are sized: if each head keeps a fixed dimension of its own, more heads means more parameters and longer training, whereas if the model dimension is split evenly across the heads (as in most standard implementations), the parameter count stays the same. The next bit is more of an opinion.

When you have several heads per layer the heads are independent of each other. This means that the model can learn different patterns with each head. For example, one head might pay most attention to the next word in each sentence, and another head might pay attention to how nouns and adjectives combine.
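As a minimal sketch of that independence (assuming a reasonably recent PyTorch, since the average_attn_weights argument is a newer addition), you can ask nn.MultiheadAttention to return the attention weights of each head separately and look at them one by one:

```python
import torch
import torch.nn as nn

# Toy setup: batch of 2 sequences, length 5, embedding dimension 16, 4 heads.
embed_dim, num_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 5, embed_dim)  # self-attention: query = key = value = x

# average_attn_weights=False keeps one attention matrix per head
# instead of averaging them into a single matrix.
out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)   # torch.Size([2, 5, 16])  - same shape as the input
print(attn.shape)  # torch.Size([2, 4, 5, 5]) - (batch, head, query, key)
# attn[:, h] is the attention pattern of head h, and each head's pattern
# is free to differ from the others.
```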

Having several heads per layer is similar to having several kernels in a convolutional layer.

Having several heads per layer allows one model to try out several pathways at once. It often turns out that some of the heads are not doing anything useful, but that’s OK, because the later layers can learn to ignore those heads.

It is possible to train a model with lots of heads and then cut some of them away. This is called pruning. Note that some researchers prune away whole heads, and other researchers prune away the least useful weights within a head. I believe you can prune a model either after pre-training or after fine-tuning.
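For whole-head pruning, the Hugging Face transformers library has a prune_heads method on its pre-trained models. A rough sketch (the layer and head indices below are arbitrary placeholders; in practice you would choose them from some importance measure):

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

def n_params(m):
    return sum(p.numel() for p in m.parameters())

before = n_params(model)

# Remove heads 0 and 2 from layer 0, and head 5 from layer 11.
model.prune_heads({0: [0, 2], 11: [5]})

after = n_params(model)
print(f"parameters before: {before:,}  after: {after:,}")
```

The pruned model keeps the same interface; it simply has fewer attention heads (and fewer parameters) in the pruned layers.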

If you want to look at what patterns each individual head is learning, I recommend a visualisation tool called BertViz by Jesse Vig. (Note that it only works with PyTorch, not TensorFlow.)
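A typical BertViz call looks roughly like this (it is meant to be run in a Jupyter notebook, and assumes the bertviz package is installed):

```python
from transformers import AutoModel, AutoTokenizer
from bertviz import head_view  # pip install bertviz

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

inputs = tokenizer("The quick brown fox jumps over the lazy dog",
                   return_tensors="pt")
outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# Renders an interactive view with one colour per head, for every layer.
head_view(outputs.attentions, tokens)
```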


Thank you! So I should look at the number of heads as a hyperparameter, and basically a large or small nhead doesn’t by itself lead to better or worse generalization?

Yes, you can treat it as a hyperparameter, though I prefer to think of it as a design decision. If you want to take advantage of transfer learning, then you will probably be copying someone else’s pre-trained model, which means nheads is already fixed to whatever its authors chose. However, if you have several such pre-trained models to choose from, you can try them all and pick the best one.

For example, if you want to use BERT, there are many different pre-trained models (Base, Large, Tiny, Small, Medium) as well as several pre-trained BERT-like variations (DistilBERT, RoBERTa, ALBERT). You can check out the nheads and the nlayers for any of those.
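You can read those numbers straight from each model’s config without downloading the weights, something like this (DistilBERT names the fields slightly differently, hence the getattr fallbacks):

```python
from transformers import AutoConfig

checkpoints = [
    "bert-base-uncased",
    "bert-large-uncased",
    "distilbert-base-uncased",
    "roberta-base",
    "albert-base-v2",
]

for name in checkpoints:
    cfg = AutoConfig.from_pretrained(name)
    heads = getattr(cfg, "num_attention_heads", None) or getattr(cfg, "n_heads", None)
    layers = getattr(cfg, "num_hidden_layers", None) or getattr(cfg, "n_layers", None)
    print(f"{name}: {heads} heads, {layers} layers")
```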

For my analysis, I used BERT-base and compared it with DistilBERT. BERT-base gave slightly better results but took much longer to train, which is pretty much what the DistilBERT authors claim.

Up to a point, more heads are likely to mean better results, as long as you have enough data. However, if a BERT-small model already has enough parameters to learn all the signal there is in the data, then adding more heads (for example by switching to BERT-base) won’t do any better.


I believe changing the number of heads does not change the number of learnable parameters, because the embedding dimension is divided by the number of heads. For a model with d_embd = 192 and 3 heads, each head has dimension 64, while 6 heads gives heads of dimension 32; after self-attention the outputs of all the heads are concatenated to get back the same d_embd of 192.
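This is how PyTorch’s nn.MultiheadAttention works, and you can check the parameter count directly (a small sketch using the 192-dimensional example above):

```python
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

for num_heads in (1, 3, 6):
    mha = nn.MultiheadAttention(embed_dim=192, num_heads=num_heads)
    # head_dim = 192 // num_heads, so the projection matrices keep the same size.
    print(num_heads, n_params(mha))

# All three print the same number: the Q, K, V and output projections are
# always 192 x 192, regardless of how they are split into heads.
```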