Questions on tensor parallelism using PyTorch

Consider the following case: 8 GPUs with ranks 0, 1, 2, 3, 4, 5, 6, 7. Assume we implement tensor parallelism and data parallelism according to the following scheme: tensor group 0 contains ranks 0, 1, 2, 3; tensor group 1 contains ranks 4, 5, 6, 7; the data-parallel groups are [0, 4], [1, 5], [2, 6], [3, 7] (a sketch of how these groups could be created is included after the questions). QUESTIONS:

  1. Given this scheme, what will the communication pattern look like? Describe it from the perspective of each GPU in tensor group 1.
  2. What is the difference between the above scheme and a single tensor group containing all 8 GPUs?
  3. Consider rank #1 and rank #5: each should handle a distinct and unique portion of the dataset. What fraction of the dataset is that portion? 1/2? 1/4? 1/8?
  4. At the end of the forward pass, will rank #5 all-reduce with rank #1 only?
  5. What happens during the backward pass?
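
For reference, here is a minimal sketch (untested) of how the tensor-parallel and data-parallel groups described above could be created with `torch.distributed.new_group`. It assumes an NCCL backend and a launcher such as `torchrun --nproc_per_node=8` that sets the usual rendezvous environment variables; the function and variable names are my own.

```python
# Sketch only: 8 GPUs, one process per GPU, launched e.g. with
# torchrun --nproc_per_node=8, which sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
import torch
import torch.distributed as dist

def setup_groups():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Tensor-parallel groups: {0,1,2,3} and {4,5,6,7}.
    tp_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
    # Data-parallel groups: one rank from each tensor group, same local position.
    dp_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]

    tp_group = dp_group = None
    # new_group() is collective: every rank must call it for every group,
    # in the same order, even for groups it does not belong to.
    for ranks in tp_groups:
        g = dist.new_group(ranks=ranks)
        if rank in ranks:
            tp_group = g
    for ranks in dp_groups:
        g = dist.new_group(ranks=ranks)
        if rank in ranks:
            dp_group = g
    return tp_group, dp_group
```

Recent PyTorch releases also provide `torch.distributed.device_mesh.init_device_mesh`, which, if available in your version, can express the same layout as a 2 x 4 mesh with named dimensions such as ("dp", "tp").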