How does one reinitialize the weights of a Hugging Face LLaMA v2 model the official way, i.e., the same way the original model was initialized?

I want to re-initialize the weights of a LLaMA v2 model I’m using/downloading. I went through the docs and their HF source code:

I tried the very simple test of going through the modules/parameters, reinitializing them the way their code suggests, and printing whether the weight norm changed. It never changed, so I don't know if there is some mutation protection going on in PyTorch/HF models. Is there something I might be doing wrong?

import torch
from transformers import AutoModelForCausalLM, AutoConfig 
import torch.nn as nn

def main_reinit_model():
    """
    ref: https://stackoverflow.com/questions/76971761/how-to-adapt-llama-v2-model-to-less-than-7b-parameters
    ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L721
    ref: https://chat.openai.com/c/977d0cb0-b819-48ac-be5c-6e482ad5e518 
    """
    print('Starting to reinitialize the model...')
    # Load the pretrained LLaMA v2 config
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    # print(f'config: {config} {type(config)}')
    # Print the original number of parameters 
    model = AutoModelForCausalLM.from_config(config) 
    # put model on device cuda
    model = model.to('cuda')
    # print the model's device
    print(f'{model.device=}')
    # print(f'{model=}')
    # print("Original number of parameters:", sum(p.numel() for p in model.parameters()))
    # go through all parameters and compute the l1 norm and sum it then print it
    norm_model = sum(p.norm(1) for p in model.parameters())
    # loop through the model's modules and reinitialize the weights with N(mean=0, std=0.02)
    print(f'{norm_model=}')
    """
    go through model and print all laters
    """
    # model.init_weights()  # didn't work
    # model._init_weights(module)  # didn't work needs module
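    # model.apply(model._init_weights)  # untested idea: nn.Module.apply passes each submodule in, which seems to be what _init_weights expects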
    # for name, param in model.named_parameters():
    #     model._init_weights(param)
    # model.post_init()
    reinitialize_weights(model)
    # model._initialize_weights(module)  # didn't work needs module
    # for name, param in model.named_parameters():
    #     print(f'{name=} {param.shape=}')
    norm_model = sum(p.norm(1) for p in model.parameters())
    print(f'{norm_model=}')
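    # note-to-self (hypothesis): from_config already initializes Linear/Embedding weights with
    # N(0, config.initializer_range=0.02), so re-drawing from the same N(0, 0.02) should give a
    # statistically identical total L1 norm; maybe that's why the two printed totals match?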

def reinitialize_weights(model) -> None:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

def _init_weights(self, module):
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=100.0, std=std)  # note: HF's original _init_weights uses mean=0.0
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

def main_generate_smaller_model():
    """
    ref: https://stackoverflow.com/questions/76971761/how-to-adapt-llama-v2-model-to-less-than-7b-parameters
    """
    print('Starting to reinitialize the model...')
    # Load the pretrained LLaMA v2 config
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    print(f'config: {config} {type(config)}')
    # Print the original number of parameters 
    model = AutoModelForCausalLM.from_config(config) 
    print("Original number of parameters:", sum(p.numel() for p in model.parameters()))

    # Modify the config to reduce size
    config.hidden_size = 2048
    config.num_hidden_layers = 12

    # Create new smaller model from modified config
    smaller_model = AutoModelForCausalLM.from_config(config)
    print("New number of parameters:", sum(p.numel() for p in smaller_model.parameters()))

if __name__ == '__main__':
    import time
    start = time.time()
    # main_generate_smaller_model() 
    main_reinit_model()
    print('Done!\a\a\a')

and the output never showed the weight norm changing:

Starting to reinitialize the model...
model.device=device(type='cuda', index=0)
norm_model=tensor(1.0779e+08, device='cuda:0', grad_fn=<AddBackward0>)
norm_model=tensor(1.0779e+08, device='cuda:0', grad_fn=<AddBackward0>)
Done!

What am I doing wrong? I just need to know how to reinitialize the weights in the proper/correct way according to LLaMA. What exact init method and values should I use?


Related

I have faced the same issue when replicating Bigger, Better, Faster.

As far as I understand, you apply weight initialization to a model only before you train it. I didn’t test whether it is possible to overwrite an existing initialization before training.

You must provide the specific layers you want to reinitialize to your model again, and you should initialize these layers before passing them to the model.

model = AutoModelForCausalLM.from_config(config)

new_llama_block = LLaMa_Block()  # or something like that
new_llama_block.apply(_init_weights)

model.blocks[-1] = new_llama_block
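
For LLaMA v2 specifically, the decoder blocks live under model.model.layers (LlamaForCausalLM wraps a LlamaModel, which holds the ModuleList of LlamaDecoderLayer), so a rough, untested sketch of the same idea, using HF's default std of 0.02 and an illustrative helper name, would be:

import torch.nn as nn

def _init_block_weights(module, std: float = 0.02):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

# reinitialize just the last decoder block in place
model.model.layers[-1].apply(_init_block_weights)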

I do get different values now:

Starting to reinitialize the model...
-- NORM OF ENTIRE NET BEFORE REINITIALIZATION:
Total weight norm: 107792008.1875
-- NORM OF ENTIRE NET AFTER REINITIALIZATION:
Total weight norm: 3412126.090576172
Done!

but now I am wondering, what value of std to use…
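
One data point: HF's own _init_weights draws from a normal with std = config.initializer_range, and that is 0.02 for Llama-2-7b (see the config dump in the code below). A quick sanity snippet to check it:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
print(config.initializer_range)  # 0.02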

code

"""
Original size of LLaMA v2 model: 7B parameters:
{
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.31.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

"""
import torch
from transformers import AutoModelForCausalLM, AutoConfig
import torch.nn as nn
from pathlib import Path
import datasets
from datasets import load_dataset, interleave_datasets
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments, AutoConfig
import math
import wandb
import os

def num_params(model: nn.Module) -> int:
    # print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    return sum(p.numel() for p in model.parameters())

def get_weight_norms(model: nn.Module, verbose: bool = False) -> float:
    """
    Sums the L1 norms of the weights of each module in the given PyTorch model
    (and optionally prints the per-module norms).

    Args:
    model (nn.Module): The PyTorch model whose weight norms are to be summed.

    Returns:
    float: The total L1 weight norm over all modules with a 'weight' tensor.
    """
    total_weight_norm: float = 0.0
    for name, module in model.named_modules():
        # Check if the module has the 'weight' attribute
        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
            # Calculate the L1 norm of the weights
            w_norm: float = module.weight.norm(1).item()
            total_weight_norm += w_norm
            if verbose:
                print(f"Norm of weights in module {name}: {w_norm}")
    return total_weight_norm

def reinitialize_weights(model, 
                         std: float = 0.0002,  # HF's default initializer_range is 0.02
                         ) -> None:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # nn.init.normal_(module.weight, mean=0, std=0.02)
            nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

def get_microscopic_llama2(verbose: bool = True):
    raise NotImplementedError
    # return get_smaller_llama2(hidden_size=2, num_hidden_layers=3, verbose=verbose)

def get_deafult_smallest_llama2(verbose: bool = True):
    return get_smaller_llama2(hidden_size=32, num_hidden_layers=1, verbose=verbose)

def get_full_llama7b(gpu_idx: int = -1):
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")
    model = AutoModelForCausalLM.from_config(config)
    if gpu_idx >= 0:
        device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
    return model

def get_smaller_llama2(hidden_size : int = 2048, 
                       num_hidden_layers : int = 12, 
                       verbose : bool = False,):
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    config.hidden_size = hidden_size
    config.num_hidden_layers = num_hidden_layers
    smaller_model = AutoModelForCausalLM.from_config(config)
    if verbose:
        print(f'config: {config}')
        print("Smaller number of parameters:", sum(p.numel() for p in smaller_model.parameters()))
    return smaller_model

def _test_generate_smaller_model():
    """
    ref: https://stackoverflow.com/questions/76971761/how-to-adapt-llama-v2-model-to-less-than-7b-parameters
    """
    print('Starting to generate a smaller model...')
    # Load the pretrained LLaMA v2 config
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    print(f'config: {config} {type(config)}')
    print()
    # Print the original number of parameters 
    model = AutoModelForCausalLM.from_config(config) 
    print("Original number of parameters:", sum(p.numel() for p in model.parameters()))

    # Modify the config to reduce size
    config.hidden_size = 2048
    config.num_hidden_layers = 12

    # Create a new smaller model from the modified config
    smaller_model = AutoModelForCausalLM.from_config(config)
    print("New number of parameters:", sum(p.numel() for p in smaller_model.parameters()))

def _test_reinit_model():
    """ 
    export CUDA_VISIBLE_DEVICES=6
    """
    torch.cuda.empty_cache() 
    print('Starting to reinitialize the model...')
    
    # - Get smaller llama2 model
    # model = get_deafult_smallest_llama2()
    model = get_full_llama7b()
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # - Check norm before reinitialization
    print("-- NORM OF ENTIRE NET BEFORE REINITIALIZATION:")
    total_weight_norm = get_weight_norms(model)
    print(f"Total weight norm: {total_weight_norm}")
    # - Reinitialize weights
    reinitialize_weights(model)
    print("-- NORM OF ENTIRE NET AFTER REINITIALIZATION:")
    total_weight_norm = get_weight_norms(model)
    print(f"Total weight norm: {total_weight_norm}")

if __name__ == '__main__':
    import time
    start = time.time()
    _test_reinit_model()
    print('Done!\a\a\a')

Make sure to do

model.to(torch.bfloat16)

if you want the model to be in bfloat16.
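
The full code below also guards this with a device-capability check, since bfloat16 needs compute capability >= 8 (Ampere or newer); roughly:

torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float32
model = model.to(torch_dtype)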

full code:

"""
Original size of LLaMA v2 model: 7B parameters:
{
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.31.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

"""
import torch
from transformers import AutoModelForCausalLM, AutoConfig
import torch.nn as nn
from pathlib import Path
import datasets
from datasets import load_dataset, interleave_datasets
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments, AutoConfig
# from transformers.models.llama.modeling_llama import LlamaRMSNorm
import math
import wandb
import os

def num_params(model: nn.Module) -> int:
    # print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    return sum(p.numel() for p in model.parameters())

def get_weight_norms(model: nn.Module, verbose: bool = False) -> float:
    """
    Sums the L1 norms of the weights of each module in the given PyTorch model
    (and optionally prints the per-module norms).

    Args:
    model (nn.Module): The PyTorch model whose weight norms are to be summed.

    Returns:
    float: The total L1 weight norm over all modules with a 'weight' tensor.
    """
    total_weight_norm: float = 0.0
    for name, module in model.named_modules():
        # Check if the module has the 'weight' attribute
        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
            # Calculate the L1 norm of the weights
            w_norm: float = module.weight.norm(1).item()
            total_weight_norm += w_norm
            if verbose:
                print(f"Norm of weights in module {name}: {w_norm}")
    return total_weight_norm

# -- reinit (after you've created the new arch you want)

def reinitialize_weights(model, 
                         std: float = 0.0002,  # HF's default is 0.02, but Hailey S. doesn't recommend such a large value
                         ) -> None:
    """
    From cs197 we originally chose std = 0.02 because of these two links
    (the default is set to 0.02 in the source; see line 858 of the first link and line 127 of the second):
    https://github.com/huggingface/transformers/blob/772307be7649e1333a933cfaa229dc0dec2fd331/src/transformers/models/llama/modeling_llama.py#L858
    https://github.com/huggingface/transformers/blob/772307be7649e1333a933cfaa229dc0dec2fd331/src/transformers/models/llama/configuration_llama.py#L127
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # nn.init.normal_(module.weight, mean=0, std=0.02)
            nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

def reinitialize_weights_xavier(model) -> None:
    """ Reinit Linear layers with Xavier (Glorot) normal init. """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

def reinitialize_weights_kaiming(model) -> None:
    """ 
    Reinit with Kaiming, ref: https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_ 

    torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
    W ~ U(-bound, bound), where bound = gain * sqrt(3 / fan_mode)
    mode = either 'fan_in' (default) or 'fan_out'. 
    Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. 
    Choosing 'fan_out' preserves the magnitudes in the backward pass.
    Recommended only with 'relu' or 'leaky_relu' (the default).
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif 'norm' in name.lower() or 'norm' in str(module).lower():
            # e.g., LlamaRMSNorm: set the gain to 1; RMSNorm has no bias, hence the guards
            if getattr(module, 'weight', None) is not None:
                nn.init.constant_(module.weight, 1)
            if getattr(module, 'bias', None) is not None:
                nn.init.constant_(module.bias, 0)
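
# e.g., for a q_proj with in_features=4096 (full llama2 7b) and the kaiming_uniform_ defaults
# (a=0, nonlinearity='leaky_relu' => gain=sqrt(2)): bound = sqrt(2) * sqrt(3/4096) = sqrt(6/4096) ~= 0.038,
# i.e. W ~ U(-0.038, 0.038)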

def reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model, 
                                                        L: int,  # for Beyond Scale we filled the data to the block size, which is 4096 (llama2's max sequence length)
                                                        ):
    
    """
    Note: we roughly follow gpt-neox-20B (2022); the llama1/llama2 papers do not say how they initialize.

    I think gpt-neox-20B and llama2 both use pre-layernorm, because gpt-neox-20B uses the init from
    Transformers without Tears and llama1 says it uses pre-norm, so maybe both are pre-layernorm.
    Thus, I hope the Transformers without Tears init (the same one gpt-neox-20B uses) works here. 
    
    Init:
    FF layers (as in Wang 2021, not Transformers without Tears):
        -> W ~ N(0, std) with std = 3 / (L * sqrt(D))
        chosen because Wang 2021 is more recent than Transformers without Tears (Nguyen & Salazar, 2019)
    Other layers (as in Transformers without Tears, Nguyen & Salazar, 2019):
        -> W ~ N(0, std) with std = sqrt(2 / (d + 4d))
    Norm layer:
        gpt-neox-20B uses LayerNorm
        llama2 follows llama1, which uses RMSNorm (Zhang and Sennrich, 2019)
        decided not to copy gpt-neox-20B (which uses W ~ N(0, sqrt(2 / (d + 4d))))
        because they don't share the same norm. llama1/2 use RMSNorm:
            a_bar_i = g_i * a_i / RMS(a), with RMS(a) = sqrt(1/n * sum_j a_j^2)  [where is eps?]
        So I decided:
        -> g_i (gain) ~ constant(1)
        since we aren't training to completion, so perhaps it helps at the beginning; if it diverges we can set it smaller, or to whatever gpt-neox-20B uses.
        There is no offset, but I will set it to 0 in the code WLOG.
    Activation:
        SwiGLU (not ReLU) for llama1/llama2 [and for our baby llama2]
        gpt-neox-20B ... the paper doesn't say.
    We use a normal distribution because Transformers without Tears uses one, and since gpt-neox-20B uses nearly the same inits, llama2 likely does too. 

    ref: RMSNorm https://arxiv.org/pdf/1910.07467.pdf
    ref: llama1 (llama2 uses the same arch) https://arxiv.org/pdf/2302.13971.pdf 
    ref: PyTorch inits https://pytorch.org/docs/stable/nn.init.html

    ref: llama2 7b config: https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L13 
    ref: follow-up discussion: https://discuss.huggingface.co/t/how-to-choose-std-for-weight-init-for-llama-2-after-reinitialize/69702

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 96, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=96, out_features=96, bias=False)
          (k_proj): Linear(in_features=96, out_features=96, bias=False)
          (v_proj): Linear(in_features=96, out_features=96, bias=False)
          (o_proj): Linear(in_features=96, out_features=96, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=96, out_features=11008, bias=False)
          (up_proj): Linear(in_features=96, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=96, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=96, out_features=32000, bias=False)
)
    (this model comes from get_smaller_llama2(hidden_size=32*3, num_hidden_layers=32),
    so in_features = 96 ==> D = 96)

-- NORM OF ENTIRE NET BEFORE REINITIALIZATION:
Total weight norm: 1742214.2911224365
-- NORM OF ENTIRE NET AFTER REINITIALIZATION:
Total weight norm: 19383.956434190273
Done!

some stds
7.47524945917718e-05
0.0035355339059327377
7.47524945917718e-05
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):  # all linear layers including MLP and attention, let's try this first given it's smaller
        # if 'gate_proj' == name or 'up_proj' == name or 'down_proj' == name or 'lm_head' == name:  # all FF/MLP layers (not attention)
            D = module.in_features  # I think this is the right size, since nn.Linear computes x @ W^T with W of shape [out_features, in_features]
            # L = module.weight.shape[1]  # I don't think you can get this from the module
            std = 3 / (L * (D)**0.5)
            nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:  # don't think biases matter cuz bias=False in all layers
                nn.init.constant_(module.bias, 0)
        # elif isinstance(module, LlamaRMSNorm):
        # if name == 'norm' or name == 'input_layernorm' or name == 'post_attention_layernorm':
        # note: str(model.model.layers[0].input_layernorm) == 'LlamaRMSNorm()'
        elif str(module) == 'LlamaRMSNorm()':
            if hasattr(module, 'weight'):
                if module.weight is not None:  # todo: idk if needed for layer norm
                    nn.init.constant_(module.weight, 1)
            if hasattr(module, 'bias'):  # I don't think RMSNorm has a bias; the whole point of RMSNorm is that centering doesn't matter, and a bias is a kind of centering
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        else:  
            if hasattr(module, 'weight'):
                if module.weight is not None: 
                    D = module.weight.shape[0]
                    # L = module.weight.shape[1]  # I don't think you can get this from the module
                    std = (2 / (D + 4*D))**0.5  # "small init" from Transformers without Tears; in this model this branch mainly hits the embedding
                    nn.init.normal_(module.weight, mean=0, std=std)
            if hasattr(module, 'bias'):
                if module.bias is not None:  # don't think biases matter cuz bias=False in all layers
                    nn.init.constant_(module.bias, 0)
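
# -- sanity check (illustrative helper, not used above): the 'some stds' listed in the docstring
# should come out of the two formulas, assuming D=96 (hidden size of the baby model), L=4096
# (the block size passed in below), and a 32000-token vocab for the embedding
def _demo_expected_stds(L: int = 4096, D: int = 96, vocab: int = 32000) -> None:
    print(3 / (L * D**0.5))                  # ~7.475e-05: linear layers with in_features=96
    print((2 / (vocab + 4 * vocab))**0.5)    # ~3.536e-03: embedding, weight.shape[0]=32000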

# - get just the arch, then you have to reinit it

def get_microscopic_llama2(verbose: bool = True):
    raise NotImplementedError
    # return get_smaller_llama2(hidden_size=2, num_hidden_layers=3, verbose=verbose)

def _get_deafult_smallest_llama2_debugging(verbose: bool = True):
    return get_smaller_llama2(hidden_size=32, num_hidden_layers=1, verbose=verbose)

def get_deafult_smallest_baby_llama2_v1_36m_0p036b(verbose: bool = False):
    """ 
    with hps: 
        hidden_size=32, num_hidden_layers=32
    num_params = 35_997_728

    Starting to reinitialize the model...
    Original number of parameters: 35997728
    -- NORM OF ENTIRE NET BEFORE REINITIALIZATION:
    Total weight norm (before): 576430.1846704483
    -- NORM OF ENTIRE NET AFTER REINITIALIZATION:
    Total weight norm (after): total_weight_norm_after_reinit=7483.21137085557
    """
    print('Warning: you might need to reinit the weights if you\'re using baby llama2.')
    return get_smaller_llama2(hidden_size=32, num_hidden_layers=32, verbose=verbose)

def get_deafult_smallest_baby_llama2_v2(verbose: bool = False):
    return get_smaller_llama2(hidden_size=32*3, num_hidden_layers=32, verbose=verbose)

def get_full_llama7b(gpu_idx: int = -1):
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")
    model = AutoModelForCausalLM.from_config(config)
    # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", trust_remote_code=True, torch_dtype=torch.bfloat16, use_auth_token=True,)
    if gpu_idx >= 0:
        device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8 else torch.float32  # compute capability >= 8 (Ampere+) ==> bfloat16 available, otherwise fall back to fp32
        model = model.to(torch_dtype)
    return model

def get_smaller_llama2(hidden_size : int = 2048, 
                       num_hidden_layers : int = 12, 
                       return_tokenizer: bool = False, 
                       verbose : bool = False,
                       ):
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    config.hidden_size = hidden_size
    config.num_hidden_layers = num_hidden_layers
    smaller_model = AutoModelForCausalLM.from_config(config)
    # NOTE: putting torch_dtype in the config doesn't work, so you have to move the model to bfloat16 later with model.to(torch.bfloat16)
    torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8 else torch.float32  # compute capability >= 8 (Ampere+) ==> bfloat16 available, otherwise fall back to fp32
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    smaller_model = smaller_model.to(device)
    smaller_model = smaller_model.to(torch_dtype)
    print(f'Model is currently on: {next(iter(smaller_model.parameters())).dtype=}')
    if verbose:
        print(f'config: {config}')
        print("Smaller number of parameters:", sum(p.numel() for p in smaller_model.parameters()))
        print(f'Model is currently on: {next(iter(smaller_model.parameters())).device=}')
        print(f'Model is currently on: {next(iter(smaller_model.parameters())).dtype=}')
        print()
    if return_tokenizer:
        tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', padding_side="right", use_fast=False, trust_remote_code=True, use_auth_token=True)
        return smaller_model, tokenizer
    return smaller_model

# def _test_generate_smaller_model():
#     """
#     ref: https://stackoverflow.com/questions/76971761/how-to-adapt-llama-v2-model-to-less-than-7b-parameters
#     """
#     print('Starting to generate a smaller model...')
#     # Load the pretrained LLaMA v2 config
#     config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
#     print(f'config: {config} {type(config)}')
#     print()
#     # Print the original number of parameters 
#     model = AutoModelForCausalLM.from_config(config) 
#     print("Original number of parameters:", sum(p.numel() for p in model.parameters()))

#     # Modify the config to reduce size
#     config.hidden_size = 2048
#     config.num_hidden_layers = 12

#     # Create a new smaller model from the modified config
#     smaller_model = AutoModelForCausalLM.from_config(config)
#     print("New number of parameters:", sum(p.numel() for p in smaller_model.parameters()))

def _test_reinit_model():
    """ 
export CUDA_VISIBLE_DEVICES=6
    """
    torch.cuda.empty_cache() 
    print('Starting to reinitialize the model...')
    
    # - Get smaller llama2 model
    # model = get_deafult_smallest_llama2()
    model = get_deafult_smallest_baby_llama2_v1_36m_0p036b()
    # model = get_deafult_smallest_baby_llama2_v2()
    # model = get_full_llama7b()
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # - sanity checks
    print(f'Model is currently on: {next(iter(model.parameters())).device=}')
    print(f'Model is currently on: {next(iter(model.parameters())).dtype=}')
    print("Original number of parameters:", sum(p.numel() for p in model.parameters()))
    # - Check norm before reinitialization
    print("-- NORM OF ENTIRE NET BEFORE REINITIALIZATION:")
    total_weight_norm_before_reinit = get_weight_norms(model)
    print(f"Total weight norm (before): {total_weight_norm_before_reinit}")
    # - Reinitialize weights
    # reinitialize_weights(model)
    # reinitialize_weights_kaiming(model)
    reinitialize_weights_gpt_neox_20B_inspired_4_llama2(model, L=4096)
    print("-- NORM OF ENTIRE NET AFTER REINITIALIZATION:")
    total_weight_norm_after_reinit = get_weight_norms(model)
    print(f"Total weight norm (after): {total_weight_norm_after_reinit=}")
    assert total_weight_norm_before_reinit != total_weight_norm_after_reinit, 'Error: the total weight norm did not change after reinitialization.'
    assert total_weight_norm_before_reinit > total_weight_norm_after_reinit, 'Error: the total weight norm should be smaller after reinitialization (we reinit with a smaller std).'

if __name__ == '__main__':
    import time
    start = time.time()
    print()
    _test_reinit_model()
    print('Done!\a\a\a')
