I am looking at Octavius right now, which introduces network-level routing with a mixture of experts. Its router takes the network input and computes expert scores, which are then passed down to the target modules, e.g. q_proj and v_proj.
My problem is that I don’t understand how this can work with gradient checkpointing enabled, because I thought gradient checkpointing encapsulates the decoder blocks. The scores are not passed down via forward but set directly as instance variables (q_proj.scores = ...), so I would expect them not to be taken into account during the backward pass.
Is it possible to prevent this somehow?
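To make the setup concrete, here is a minimal sketch of what I mean. It is not Octavius code: `ScoredLinear` and `router_grad` are hypothetical names I made up, the router is just a single linear layer, and I am assuming PyTorch's `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` (the variant that records the autograd graph during the forward pass). It compares the router's gradient with and without checkpointing when the scores are set as an attribute instead of being passed through `forward`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ScoredLinear(nn.Module):
    """Hypothetical stand-in for a routed q_proj / v_proj."""

    def __init__(self, dim, num_experts):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
        self.scores = None  # set from outside, NOT passed through forward()

    def forward(self, x):
        # Stack expert outputs: (batch, dim, num_experts).
        outs = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # Weight each expert by its score; scores has shape (batch, num_experts).
        return (outs * self.scores.unsqueeze(1)).sum(dim=-1)


def router_grad(use_checkpoint):
    """Return the gradient of the (assumed) router weight after one backward pass."""
    torch.manual_seed(0)  # same initialization and input for both runs
    router = nn.Linear(8, 2)
    block = ScoredLinear(8, 2)
    x = torch.randn(4, 8, requires_grad=True)

    # Network-level routing: scores are computed once and attached as an attribute.
    block.scores = torch.softmax(router(x), dim=-1)

    if use_checkpoint:
        # Non-reentrant checkpointing; the recomputation re-reads block.scores,
        # so the attribute must not be overwritten before backward runs.
        y = checkpoint(block, x, use_reentrant=False)
    else:
        y = block(x)

    y.sum().backward()
    return router.weight.grad.clone()


if __name__ == "__main__":
    print(torch.allclose(router_grad(True), router_grad(False)))
```

If the two gradients match, the scores do take part in the backward pass in this toy case. I am not sure whether that carries over to the reentrant variant or to the way Octavius wires things up.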