PaLM (150m, 410m, 1B) - Uploading custom models

Hi @sgugger,

I recently finished pre-training a series of PaLM models (150m, 410m, 1B) on C4. I am attempting to upload them as custom models, but the layers are not being initialized properly when the model is loaded. The pytorch_model.bin weights and files for the 1B model are stored here: conceptofmind/palm-1b · Hugging Face
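
For reference, this is roughly how I am trying to load the model once everything is on the Hub (a minimal sketch; trust_remote_code is what pulls in the custom config and modeling files through auto_map):

from transformers import AutoConfig, AutoModel

repo_id = "conceptofmind/palm-1b"

# auto_map in config.json points AutoConfig/AutoModel at the custom classes below
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)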

The files I am using are defined below.

config.json:

{
  "architectures": [
    "PaLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_palm.PalmConfig",
    "AutoModel": "modeling_palm.PaLMForCausalLM"
  },
  "dim": 2048,
  "num_tokens": 50304,
  "depth": 16,
  "causal": true,
  "dim_head": 128,
  "heads": 8,
  "ff_mult": 4,
  "attn_dropout": 0.0,
  "ff_dropout": 0.0,
  "rotary_xpos_scale_base": 512,
  "flash_attn": true,
  "cross_entropy_ignore_index": 0,
  "model_type": "palm",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.28.1"
}

configuration_palm.py:

from transformers import PretrainedConfig

class PalmConfig(PretrainedConfig):
    model_type = "palm"

    def __init__(
        self,
        dim=2048,
        num_tokens=50304,
        depth=16,
        causal=True,
        dim_head=128,
        heads=8,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        rotary_xpos_scale_base=512,
        flash_attn=True,
        cross_entropy_ignore_index=0,
        **kwargs,
    ):
        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.causal = causal
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.rotary_xpos_scale_base = rotary_xpos_scale_base
        self.flash_attn = flash_attn
        self.cross_entropy_ignore_index = cross_entropy_ignore_index
        super().__init__(**kwargs)
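
As a quick check that the class serializes the same fields as the config.json above (a minimal sketch):

from configuration_palm import PalmConfig

config = PalmConfig()
# should contain the same fields as the config.json shown above,
# plus any defaults that PretrainedConfig adds on its own
print(config.to_json_string())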

modeling_palm.py:

import torch
from transformers import PreTrainedModel
from palm_rlhf_pytorch import PaLM
from .configuration_palm import PalmConfig

class PaLMForCausalLM(PreTrainedModel):
    config_class = PalmConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = PaLM(
            dim=config.dim,
            num_tokens=config.num_tokens,
            depth=config.depth,
            causal=config.causal,
            dim_head=config.dim_head,
            heads=config.heads,
            ff_mult=config.ff_mult,
            attn_dropout=config.attn_dropout,
            ff_dropout=config.ff_dropout,
            rotary_xpos_scale_base=config.rotary_xpos_scale_base,
            flash_attn=config.flash_attn,
            cross_entropy_ignore_index=config.cross_entropy_ignore_index
        )

    def forward(self, tensor):
        loss = self.model(tensor, return_loss=True)
        return {"loss": loss}

    def _init_weights(self, module):
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1.0)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
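
For completeness, this is roughly how I register the custom classes and save everything so that the auto_map entries above get written out (a sketch following the custom-models workflow; the local directory name is just an example):

from configuration_palm import PalmConfig
from modeling_palm import PaLMForCausalLM

# make the classes resolvable via auto_map / trust_remote_code
PalmConfig.register_for_auto_class()
PaLMForCausalLM.register_for_auto_class("AutoModel")

config = PalmConfig()
model = PaLMForCausalLM(config)

# copies configuration_palm.py / modeling_palm.py next to config.json and the weights
model.save_pretrained("palm-1b")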

The PaLM model follows this structure:

PaLM(
  (token_emb): Embedding(50304, 2048)
  (layers): ModuleList(
    (0-15): 16 x Residual(
      (fn): ParallelTransformerBlock(
        (norm): LayerNorm()
        (attend): Attention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (rotary_emb): RotaryEmbedding()
        (fused_attn_ff_proj): Linear(in_features=2048, out_features=17664, bias=False)
        (attn_out): Linear(in_features=1024, out_features=2048, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (ff_out): Sequential(
          (0): SwiGLU()
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=8192, out_features=2048, bias=False)
        )
      )
    )
  )
  (norm): LayerNorm()
  (to_logits): Linear(in_features=2048, out_features=50304, bias=False)
  (finetune_modules): ModuleDict()
)

The state dict keys are:

token_emb.weight
layers.0.fn.norm.gamma
layers.0.fn.fused_attn_ff_proj.weight
layers.0.fn.attn_out.weight
layers.0.fn.ff_out.2.weight
layers.1.fn.norm.gamma
layers.1.fn.fused_attn_ff_proj.weight
layers.1.fn.attn_out.weight
layers.1.fn.ff_out.2.weight
layers.2.fn.norm.gamma
layers.2.fn.fused_attn_ff_proj.weight
layers.2.fn.attn_out.weight
layers.2.fn.ff_out.2.weight
layers.3.fn.norm.gamma
layers.3.fn.fused_attn_ff_proj.weight
layers.3.fn.attn_out.weight
layers.3.fn.ff_out.2.weight
layers.4.fn.norm.gamma
layers.4.fn.fused_attn_ff_proj.weight
layers.4.fn.attn_out.weight
layers.4.fn.ff_out.2.weight
layers.5.fn.norm.gamma
layers.5.fn.fused_attn_ff_proj.weight
layers.5.fn.attn_out.weight
layers.5.fn.ff_out.2.weight
layers.6.fn.norm.gamma
layers.6.fn.fused_attn_ff_proj.weight
layers.6.fn.attn_out.weight
layers.6.fn.ff_out.2.weight
layers.7.fn.norm.gamma
layers.7.fn.fused_attn_ff_proj.weight
layers.7.fn.attn_out.weight
layers.7.fn.ff_out.2.weight
layers.8.fn.norm.gamma
layers.8.fn.fused_attn_ff_proj.weight
layers.8.fn.attn_out.weight
layers.8.fn.ff_out.2.weight
layers.9.fn.norm.gamma
layers.9.fn.fused_attn_ff_proj.weight
layers.9.fn.attn_out.weight
layers.9.fn.ff_out.2.weight
layers.10.fn.norm.gamma
layers.10.fn.fused_attn_ff_proj.weight
layers.10.fn.attn_out.weight
layers.10.fn.ff_out.2.weight
layers.11.fn.norm.gamma
layers.11.fn.fused_attn_ff_proj.weight
layers.11.fn.attn_out.weight
layers.11.fn.ff_out.2.weight
layers.12.fn.norm.gamma
layers.12.fn.fused_attn_ff_proj.weight
layers.12.fn.attn_out.weight
layers.12.fn.ff_out.2.weight
layers.13.fn.norm.gamma
layers.13.fn.fused_attn_ff_proj.weight
layers.13.fn.attn_out.weight
layers.13.fn.ff_out.2.weight
layers.14.fn.norm.gamma
layers.14.fn.fused_attn_ff_proj.weight
layers.14.fn.attn_out.weight
layers.14.fn.ff_out.2.weight
layers.15.fn.norm.gamma
layers.15.fn.fused_attn_ff_proj.weight
layers.15.fn.attn_out.weight
layers.15.fn.ff_out.2.weight
norm.gamma

Because PaLM is imported from palm_rlhf_pytorch and wrapped as self.model inside PaLMForCausalLM, the checkpoint keys do not match the wrapper's parameter names. I tried to resolve this by renaming the keys in the pytorch_model.bin state dict:

import torch
from palm_rlhf_pytorch import PaLM

# original PaLM model (instantiated for reference; not used in the renaming below)
model = PaLM(
    num_tokens=50304, dim=2048, depth=16, dim_head=128, heads=8
).cuda()

model_path = "/pytorch_model.bin"

state_dict = torch.load(model_path)
new_state_dict = {}
for key, value in state_dict.items():
    # the wrapper stores PaLM under self.model, so every key needs a "model." prefix
    new_key = 'model.' + key
    # rename the LayerNorm parameters from gamma/beta to weight/bias
    new_key = new_key.replace('gamma', 'weight')
    new_key = new_key.replace('beta', 'bias')
    new_state_dict[new_key] = value

new_model_path = "/pytorch_model.bin"
torch.save(new_state_dict, new_model_path)

Renaming all of the keys to match did not work, and loading still throws an initialization error.
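
To see what is still mismatched after the rename, I compare the checkpoint keys against what the wrapper actually expects, roughly like this (a quick diagnostic sketch):

import torch

from configuration_palm import PalmConfig
from modeling_palm import PaLMForCausalLM

wrapper = PaLMForCausalLM(PalmConfig())
expected = set(wrapper.state_dict().keys())
saved = set(torch.load("/pytorch_model.bin", map_location="cpu").keys())

print("missing from checkpoint:", sorted(expected - saved))
print("unexpected in checkpoint:", sorted(saved - expected))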

Do you have any input on how I should be loading the PyTorch model and setting up the PreTrainedModel wrapper?

Or is there simply a way to import the PaLM model and load its weights directly for use in the forward method, while still remaining compatible with the Trainer?
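
What I mean is roughly the following (a hypothetical sketch of what I am asking about, not something I have working; PaLMWrapper and checkpoint_path are made-up names):

import torch
from transformers import PreTrainedModel
from palm_rlhf_pytorch import PaLM
from configuration_palm import PalmConfig

class PaLMWrapper(PreTrainedModel):
    config_class = PalmConfig

    def __init__(self, config, checkpoint_path=None):
        super().__init__(config)
        self.model = PaLM(
            dim=config.dim,
            num_tokens=config.num_tokens,
            depth=config.depth,
            dim_head=config.dim_head,
            heads=config.heads,
        )
        if checkpoint_path is not None:
            # load the original palm_rlhf_pytorch state dict directly,
            # bypassing from_pretrained entirely
            self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    def forward(self, input_ids, labels=None, **kwargs):
        # the Trainer can read the loss from the returned dict
        loss = self.model(input_ids, return_loss=True)
        return {"loss": loss}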

Thank you,

Enrico Shippole