The best way to modify a transformers model with minimal code changes

I want to create a modified Llama model that is identical to the original Llama except that some of the FFNs are replaced with Mixture-of-Experts (MoE) layers. The new model has a different name from Llama, e.g. "BaseLlama". I care mainly about the MoE architecture and not about the other techniques (and corresponding code) in transformers.LlamaModel. I would also prefer not to have to touch my code whenever the transformers library makes minor changes to the Llama model code, so I try to inherit from as many Llama-related classes as I can.
I have a piece of experimental code below. Is it safe and efficient? Are there any bugs? Is there a better, more elegant way to implement it?

import torch
from torch import nn

# Star imports keep the snippet short; torch / nn are imported explicitly because
# the modeling module is not guaranteed to re-export them.
from transformers.models.llama.configuration_llama import *
from transformers.models.llama.modeling_llama import *

class BaseLlamaConfig(LlamaConfig):
    model_type = "base_llama"

    def __init__(self, sparse_step=2, num_experts=8, **kwargs):
        self.sparse_step = sparse_step    # every `sparse_step`-th layer gets a sparse MoE MLP
        self.num_experts = num_experts    # number of experts in each sparse MLP
        super().__init__(**kwargs)

class BaseLlamaSparseMLP(nn.Module):
    def __init__(self, config: BaseLlamaConfig):
        super().__init__()
        # Router: scores each token against every expert.
        self.classifier = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # One standard LlamaMLP per expert.
        self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(config.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # MoE routing over self.experts omitted here; this placeholder just returns
        # its input unchanged (a possible top-2 routing is sketched after the snippet).
        return hidden_states

class BaseLlamaLayer(LlamaDecoderLayer):
    def __init__(self, config: BaseLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        # Replace the dense LlamaMLP with a sparse MoE MLP every `sparse_step` layers.
        if layer_idx % config.sparse_step == config.sparse_step - 1:
            self.mlp = BaseLlamaSparseMLP(config)

class BaseLlamaPreTrainedModel(LlamaPreTrainedModel):
    config_class = BaseLlamaConfig
    # Keep each decoder layer on a single device when loading with a device_map.
    _no_split_modules = ["BaseLlamaLayer"]

class BaseLlamaModel(BaseLlamaPreTrainedModel, LlamaModel):
    def __init__(self, config: BaseLlamaConfig):
        super().__init__(config)
        # Rebuild the decoder stack so the sparse layers use BaseLlamaLayer.
        self.layers = nn.ModuleList(
            [BaseLlamaLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.post_init()

class BaseLlamaForCausalLM(BaseLlamaPreTrainedModel, LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Swap in the MoE-aware backbone in place of the plain LlamaModel.
        self.model = BaseLlamaModel(config)
        self.post_init()
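
For reference, the forward I omitted in BaseLlamaSparseMLP would do the actual token-to-expert routing. Here is a minimal sketch assuming Mixtral-style top-2 routing; the softmax router, the hard-coded k=2, and the weight renormalization are only illustrative choices, not something the class above commits to:

import torch

# Intended as the body of BaseLlamaSparseMLP.forward (hypothetical routing logic).
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq, hidden) -> (tokens, hidden) so routing is per token.
    batch, seq_len, hidden = hidden_states.shape
    flat = hidden_states.reshape(-1, hidden)

    # Router scores -> probabilities -> top-2 experts per token.
    router_logits = self.classifier(flat)                       # (tokens, num_experts)
    routing_weights = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_idx = torch.topk(routing_weights, k=2, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Accumulate each expert's output only for the tokens routed to it.
    out = torch.zeros_like(flat)
    for expert_idx, expert in enumerate(self.experts):
        token_idx, slot = torch.where(topk_idx == expert_idx)
        if token_idx.numel() == 0:
            continue
        expert_out = expert(flat[token_idx]) * topk_weights[token_idx, slot].unsqueeze(-1)
        out.index_add_(0, token_idx, expert_out)

    return out.reshape(batch, seq_len, hidden)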
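
And since the model uses a new model_type ("base_llama"), my plan is to also register the classes with the Auto factories so that AutoConfig / AutoModelForCausalLM can resolve them later. A quick usage sketch; the tiny config values are arbitrary and only meant for a fast sanity check:

import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

AutoConfig.register("base_llama", BaseLlamaConfig)
AutoModel.register(BaseLlamaConfig, BaseLlamaModel)
AutoModelForCausalLM.register(BaseLlamaConfig, BaseLlamaForCausalLM)

# Tiny config for a shape check only; the values are not meaningful.
config = BaseLlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=4,
    sparse_step=2,
    num_experts=4,
)
model = BaseLlamaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))
logits = model(input_ids).logits
print(logits.shape)  # expected: torch.Size([1, 8, 1000])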