I want to create a modified Llama model that is identical to the original Llama except that some of the FFNs are replaced with Mixture-of-Experts (MoE) layers. The new model has a different name from Llama, such as “BaseLlama”. I care mostly about the MoE architecture and not about the other techniques and the corresponding code in transformers.LlamaModel. I would prefer not to have to modify my code whenever the transformers library makes minor changes to the Llama model code, so I try to inherit from as many of the Llama-related classes as I can.
I have a piece of experimental code here. Is it safe and efficient? Are there any bugs? Is there a better and more elegant way to implement it?
import torch
from torch import nn
from transformers.models.llama.configuration_llama import *
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import *
# A wildcard import only exposes the names listed in the module's ``__all__``
# when one is defined, so every name this file relies on is imported explicitly.
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaPreTrainedModel,
)
class BaseLlamaConfig(LlamaConfig):
    """Llama configuration extended with Mixture-of-Experts settings.

    Args:
        sparse_step: Period used to decide which decoder layers get a sparse
            MLP; consumers compute ``layer_idx % sparse_step``, so the value
            must be at least 1.
        num_experts: Number of expert MLPs (and router outputs) in each
            sparse MLP; must be at least 1.
        **kwargs: Forwarded unchanged to ``LlamaConfig.__init__``.

    Raises:
        ValueError: If ``sparse_step`` or ``num_experts`` is smaller than 1.
    """

    model_type = "base_llama"

    def __init__(self, sparse_step=2, num_experts=8, **kwargs):
        # Fail early with a clear message: sparse_step == 0 would otherwise
        # surface much later as a ZeroDivisionError while building layers.
        if sparse_step < 1:
            raise ValueError(f"sparse_step must be >= 1, got {sparse_step!r}")
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts!r}")
        # Custom fields are assigned before delegating to the parent, exactly
        # as in the original ordering.
        self.sparse_step = sparse_step
        self.num_experts = num_experts
        super().__init__(**kwargs)
class BaseLlamaSparseMLP(torch.nn.Module):
    """Feed-forward block holding a bias-free linear ``classifier`` and a list of experts.

    The ``classifier`` maps ``hidden_size`` features to ``num_experts``
    outputs (presumably the routing logits -- confirm once the forward pass
    is implemented), and ``experts`` holds ``num_experts`` ``LlamaMLP``
    modules.

    NOTE: the forward pass is a placeholder in this experimental version and
    returns its input unchanged.
    """

    def __init__(self, config: BaseLlamaConfig):
        super().__init__()
        expert_count = config.num_experts
        # Construction order (classifier first, then experts) is kept so that
        # any RNG-dependent initialisation happens in the same sequence.
        self.classifier = torch.nn.Linear(config.hidden_size, expert_count, bias=False)
        self.experts = torch.nn.ModuleList(LlamaMLP(config) for _ in range(expert_count))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` as-is (routing logic intentionally omitted)."""
        # ignore detailed code
        return hidden_states
class BaseLlamaLayer(LlamaDecoderLayer):
    """Decoder layer whose ``mlp`` is swapped for a sparse MLP on selected indices.

    A layer is sparse when ``layer_idx % config.sparse_step`` equals
    ``config.sparse_step - 1``; every other layer keeps whatever ``mlp`` the
    parent constructor installed.
    """

    def __init__(self, config: BaseLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        step = config.sparse_step
        offset_in_group = layer_idx % step
        if offset_in_group != step - 1:
            # Dense layer: leave the parent-built MLP untouched.
            return
        # Last layer of each group of `step` layers: use the sparse MLP.
        self.mlp = BaseLlamaSparseMLP(config)
class BaseLlamaPreTrainedModel(LlamaPreTrainedModel):
    """Shared base for the BaseLlama model classes.

    Only overrides the class-level metadata that must name the BaseLlama
    variants instead of the stock Llama ones; everything else is inherited.
    """

    # Name of the decoder-layer class used by this model family.
    _no_split_modules = ["BaseLlamaLayer"]
    # Configuration class associated with this model family.
    config_class = BaseLlamaConfig
class BaseLlamaModel(BaseLlamaPreTrainedModel, LlamaModel):
    """Llama backbone whose decoder stack is made of ``BaseLlamaLayer`` modules.

    Everything except ``self.layers`` is whatever the parent constructor
    builds; only the layer stack is replaced here.

    Args:
        config: Model configuration; ``config.num_hidden_layers`` determines
            how many ``BaseLlamaLayer`` modules are created.
    """

    def __init__(self, config: BaseLlamaConfig):
        super().__init__(config)
        # Drop the parent-built stack *before* creating the replacement so the
        # two stacks never coexist in memory. Assigning None (rather than
        # `del`) keeps the "layers" entry at its original position among the
        # registered submodules.
        # NOTE(review): the parent still allocates its own layers transiently;
        # avoiding that would mean re-implementing its constructor.
        self.layers = None
        self.layers = nn.ModuleList(
            [BaseLlamaLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Run post-init again now that the new layers exist.
        self.post_init()
class BaseLlamaForCausalLM(BaseLlamaPreTrainedModel, LlamaForCausalLM):
    """Causal-LM wrapper that uses ``BaseLlamaModel`` as its backbone.

    Everything except ``self.model`` is whatever the parent constructor
    builds; only the backbone is replaced here.

    Args:
        config: Model configuration, passed to both the parent constructor
            and ``BaseLlamaModel``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Release the parent-built backbone before constructing the
        # replacement so two full backbones are never alive at once.
        # Assigning None (rather than `del`) keeps the "model" entry at its
        # original position among the registered submodules.
        # NOTE(review): the parent still allocates its own backbone
        # transiently; avoiding that would mean re-implementing its constructor.
        self.model = None
        self.model = BaseLlamaModel(config)
        # Run post-init again now that the new backbone is in place.
        self.post_init()