The legends over at DeepSpeed released a paper on scaling Mixture of Experts with a bunch of cool ideas.
Since they will probably release some PyTorch code soon, I wanted to summarize and discuss the findings so that I learn them better.
- I provide 0 background on Mixture of Experts and assume knowledge of Top1 vs Top2 gating, for selfish/lazy reasons. Read the DeepSpeed blog post for background.
- I abstract the term “acc” to encompass all types of metrics: validation perplexity, zero shot accuracy, etc.
- I used @srush’s trick of trying to read critically (to get your brain to think harder about other people’s results), but I don’t want to come off as too negative. I really enjoyed this paper and am excited to read the code!
The DeepSpeed team proposes:
- (a) (sec 4.1) architectural modifications that reduce the number of experts without hurting acc.
- (b) (sec 4.1) MoE 2 MoE distillation (instead of MoE 2 dense distillation, as in the FAIR paper (appendix Table 9) and the Switch paper)
- (c) (sec 5) Systems optimizations to make inference fast:
  - Improved communication collectives for MoE inference (hierarchical all2all)
  - Tutel-style single-device kernels to make routing tokens to experts fast
  - 4D parallelism!?
I now cover architecture and distillation, and save systems optimizations for later because I don’t fully understand them yet.
This section is really well written. It contains two very nice ablations that motivated the changes:
Phenomenon 1: “Pyramid”
We compare the performance of two different half-MoE architectures. More specifically, we put MoE layers in the first half of the model and leave the second half’s layers identical to the dense model. We switch the MoE layers to the second half and use dense at the first half. The results show that deeper layers benefit more from large number of experts.
So they use fewer experts in the early layers and more in the deep ones. This also saves a ton of parameters: a 40% reduction at 1.3B dense-equivalent size, which will be useful at inference time.
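To see where a reduction of that size could come from, here is a back-of-the-envelope count of expert parameters. The hidden size, the number of MoE layers, and the per-layer expert counts are my assumptions for illustration, not numbers taken from this post; `expert_params` is a hypothetical helper, and it ignores biases, gates, and every dense layer.

```python
from typing import Sequence


def expert_params(d_model: int, experts_per_layer: Sequence[int], ffn_mult: int = 4) -> int:
    """Parameters held in expert FFNs only (no biases, gates, or dense layers)."""
    # Each expert is a two-matrix FFN: d_model -> ffn_mult*d_model -> d_model.
    per_expert = 2 * d_model * (ffn_mult * d_model)
    return per_expert * sum(experts_per_layer)


D_MODEL = 2048                    # assumed hidden size for a 1.3B-class base model
uniform = [128] * 12              # assumed: 12 MoE layers with 128 experts each
pyramid = [64] * 10 + [128] * 2   # assumed: fewer experts early, more in the deepest layers

uniform_params = expert_params(D_MODEL, uniform)
pyramid_params = expert_params(D_MODEL, pyramid)
reduction = 1 - pyramid_params / uniform_params

print(f"uniform: {uniform_params / 1e9:.1f}B expert params")
print(f"pyramid: {pyramid_params / 1e9:.1f}B expert params")
print(f"reduction: {reduction:.0%}")
```

Under these assumed settings the pyramid keeps a bit under 60% of the expert parameters, which is in the same ballpark as the 40% reduction quoted above. The exact layout in the paper may differ.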
Phenomenon 2: “Residual”
we can achieve the benefit of using two experts per layer but still use one communication.
They frame this as trying to get the benefits of top2 routing without the costs.
But basically, MoE layers become only half sparse: a dense FFN processes the input, as does 1 expert, and the results are added.
Compared to top2 where 2 different sparse experts process the input, this is cheaper because there is less communication (you only need to send the input to 1 place instead of 2?)
Note this does not improve acc compared to top2, just speed.
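Here is how I picture the residual layer, as a minimal NumPy sketch rather than the DeepSpeed implementation. All names (`ffn`, `residual_moe`, the toy sizes) are hypothetical, and scaling the expert output by its gate probability is my assumption, borrowed from standard top-1 routing.

```python
import numpy as np


def ffn(x, w_in, w_out):
    """Two-layer feed-forward block with ReLU."""
    return np.maximum(x @ w_in, 0.0) @ w_out


def residual_moe(x, dense, experts, w_gate):
    """x: (tokens, d_model). Returns dense FFN output plus the top-1 expert output."""
    logits = x @ w_gate                                   # (tokens, n_experts)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)            # softmax over experts
    choice = probs.argmax(axis=-1)                        # top-1: one expert per token
    out = ffn(x, *dense)                                  # dense branch sees every token
    for e, (w_in, w_out) in enumerate(experts):
        mask = choice == e
        if mask.any():
            # Assumption: weight the expert output by its gate probability.
            out[mask] += probs[mask, e][:, None] * ffn(x[mask], w_in, w_out)
    return out, choice


rng = np.random.default_rng(0)
d_model, d_ff, n_experts, n_tokens = 8, 16, 4, 5


def make_ffn():
    return rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model))


dense = make_ffn()
experts = [make_ffn() for _ in range(n_experts)]
w_gate = rng.normal(size=(d_model, n_experts))
x = rng.normal(size=(n_tokens, d_model))

y, choice = residual_moe(x, dense, experts, w_gate)
print(y.shape, choice)
```

Each token is dispatched to exactly one expert, so in a distributed setup there is one routing round trip per layer, while the dense branch runs locally on every token.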
Putting it all together:
FAIR arch (see table 1) (52B Params)
- Layers: top2 gating (each token gets routed to 2 experts)
- 512 experts at each MoE layer
DeepSpeed arch (31B params)
- Layers: each token processed by dense FFN and 1 expert (same FLOPs as top2 gating if same number of experts, I believe).
- Pyramid: somewhere between 32 and 128 experts at each MoE layer. Way fewer params!
In terms of acc (PIQA is the only overlapping evaluation),
the 31B DeepSpeed model performs between the FAIR 52B and the FAIR 207B, and it probably had a lower training cost than the 52B, even before all the systems optimizations in section 5. Nice!
With the systems optimizations they say training is 5x faster than dense (to the same acc). The FAIR paper says “4x faster than dense”, but measures TFLOPS, which makes the extra communication required for MoE appear to be free. So all in all this definitely seems like a better architecture.
It would have been cool if Tables 2,4 had training cost and inference cost next to the few shot performances (or 1 big joined table somewhere!).
Caveat before you read this section: in most distillation results, the student model is MUCH smaller than the teacher model, like half as large or so. Here, the student is only about 12.5% smaller than the teacher (3 fewer layers and 4B fewer params: 27B vs 31B).
They are able to lose very little performance, which is nice, but they also didn’t really lose that much weight, and it would be interesting to try to replicate what they did with smaller students.
Caveat 2: the name is deeply misleading. It’s normal KD, but they switch to plain cross-entropy loss halfway through. That’s it!
Anyways, these are the first published MoE 2 MoE distillation results. The Switch paper and FAIR paper both distill MoE 2 dense models (since they are much easier to serve than MoE models, a gap DeepSpeed claims to eliminate in section 5, the one I don’t understand yet :( ).
They use the same KD loss as the other papers, but they turn it off halfway through training.
They say this improves acc, but I am most interested in the speed implications. I tried MoE2MoE distillation but it was extremely slow (like 10x slower than Dense2Dense) because of teacher inference every step.
If we could only run the teacher forward pass for part of the student training, that would be sweet!
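A toy sketch of what that schedule could look like, in plain Python on a single example. `staged_kd_loss` and `use_teacher` are hypothetical names, `alpha` is an assumed weighting, and `kd_fraction=0.5` just encodes the "halfway" from above; none of this is the paper's actual code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def staged_kd_loss(student_logits, label, teacher_logits=None, alpha=0.5):
    """Cross-entropy on the label, plus alpha * KL(teacher || student) if a teacher is given."""
    s_logp = log_softmax(student_logits)
    loss = -s_logp[label]
    if teacher_logits is not None:
        t_logp = log_softmax(teacher_logits)
        kl = sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))
        loss += alpha * kl
    return loss


def use_teacher(step, total_steps, kd_fraction=0.5):
    """True while the KD term is active; afterwards the teacher is not needed."""
    return step < kd_fraction * total_steps


total_steps = 4
student, teacher, label = [2.0, 0.5, -1.0], [3.0, 0.0, -2.0], 0
for step in range(total_steps):
    # After the switch, skip the teacher forward pass entirely.
    t = teacher if use_teacher(step, total_steps) else None
    print(step, t is not None, round(staged_kd_loss(student, label, t), 4))
```

The point of passing `None` is that the expensive part (teacher inference) only happens while `use_teacher` is true, so the second half of training costs about the same as training the student alone.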
Let me know about any inaccuracies, important omissions, what you ate for lunch, or follow-up ideas!
Next week I will try to tackle Section 5 (Systems optimizations) and if I don’t I will burn a 20 dollar bill and record it!