Hi @sgugger,
I recently finished pre-training a series of PaLM models (150m, 410m, 1B) on C4. I am attempting to upload them as custom models, but I am running into an issue where the layers are not being initialized properly. The pytorch_model.bin weights and files for the 1B model are stored here: conceptofmind/palm-1b · Hugging Face
The files I am using are included below.
config.json:
{
  "architectures": [
    "PaLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_palm.PalmConfig",
    "AutoModel": "modeling_palm.PaLMForCausalLM"
  },
  "dim": 2048,
  "num_tokens": 50304,
  "depth": 16,
  "causal": true,
  "dim_head": 128,
  "heads": 8,
  "ff_mult": 4,
  "attn_dropout": 0.0,
  "ff_dropout": 0.0,
  "rotary_xpos_scale_base": 512,
  "flash_attn": true,
  "cross_entropy_ignore_index": 0,
  "model_type": "palm",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.28.1"
}
configuration_palm.py:
from transformers import PretrainedConfig


class PalmConfig(PretrainedConfig):
    model_type = "palm"

    def __init__(
        self,
        dim=2048,
        num_tokens=50304,
        depth=16,
        causal=True,
        dim_head=128,
        heads=8,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        rotary_xpos_scale_base=512,
        flash_attn=True,
        cross_entropy_ignore_index=0,
        **kwargs,
    ):
        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.causal = causal
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.rotary_xpos_scale_base = rotary_xpos_scale_base
        self.flash_attn = flash_attn
        self.cross_entropy_ignore_index = cross_entropy_ignore_index
        super().__init__(**kwargs)
modeling_palm.py:
import torch
from transformers import PreTrainedModel
from palm_rlhf_pytorch import PaLM

from .configuration_palm import PalmConfig


class PaLMForCausalLM(PreTrainedModel):
    config_class = PalmConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = PaLM(
            dim=config.dim,
            num_tokens=config.num_tokens,
            depth=config.depth,
            causal=config.causal,
            dim_head=config.dim_head,
            heads=config.heads,
            ff_mult=config.ff_mult,
            attn_dropout=config.attn_dropout,
            ff_dropout=config.ff_dropout,
            rotary_xpos_scale_base=config.rotary_xpos_scale_base,
            flash_attn=config.flash_attn,
            cross_entropy_ignore_index=config.cross_entropy_ignore_index,
        )

    def forward(self, tensor):
        loss = self.model(tensor, return_loss=True)
        return {"loss": loss}

    def _init_weights(self, module):
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1.0)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
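For reference, this is roughly the loading path I am testing once the three files above are in the conceptofmind/palm-1b repo (a minimal sketch; trust_remote_code=True is needed because the model code lives in the repo via the auto_map entries):

from transformers import AutoConfig, AutoModel

# pull the custom PalmConfig / PaLMForCausalLM code and weights from the Hub repo
config = AutoConfig.from_pretrained("conceptofmind/palm-1b", trust_remote_code=True)
model = AutoModel.from_pretrained("conceptofmind/palm-1b", trust_remote_code=True)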
The PaLM model follows this structure:
PaLM(
  (token_emb): Embedding(50304, 2048)
  (layers): ModuleList(
    (0-15): 16 x Residual(
      (fn): ParallelTransformerBlock(
        (norm): LayerNorm()
        (attend): Attention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (rotary_emb): RotaryEmbedding()
        (fused_attn_ff_proj): Linear(in_features=2048, out_features=17664, bias=False)
        (attn_out): Linear(in_features=1024, out_features=2048, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (ff_out): Sequential(
          (0): SwiGLU()
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=8192, out_features=2048, bias=False)
        )
      )
    )
  )
  (norm): LayerNorm()
  (to_logits): Linear(in_features=2048, out_features=50304, bias=False)
  (finetune_modules): ModuleDict()
)
The parameter names are:
token_emb.weight
layers.0.fn.norm.gamma
layers.0.fn.fused_attn_ff_proj.weight
layers.0.fn.attn_out.weight
layers.0.fn.ff_out.2.weight
layers.1.fn.norm.gamma
layers.1.fn.fused_attn_ff_proj.weight
layers.1.fn.attn_out.weight
layers.1.fn.ff_out.2.weight
layers.2.fn.norm.gamma
layers.2.fn.fused_attn_ff_proj.weight
layers.2.fn.attn_out.weight
layers.2.fn.ff_out.2.weight
layers.3.fn.norm.gamma
layers.3.fn.fused_attn_ff_proj.weight
layers.3.fn.attn_out.weight
layers.3.fn.ff_out.2.weight
layers.4.fn.norm.gamma
layers.4.fn.fused_attn_ff_proj.weight
layers.4.fn.attn_out.weight
layers.4.fn.ff_out.2.weight
layers.5.fn.norm.gamma
layers.5.fn.fused_attn_ff_proj.weight
layers.5.fn.attn_out.weight
layers.5.fn.ff_out.2.weight
layers.6.fn.norm.gamma
layers.6.fn.fused_attn_ff_proj.weight
layers.6.fn.attn_out.weight
layers.6.fn.ff_out.2.weight
layers.7.fn.norm.gamma
layers.7.fn.fused_attn_ff_proj.weight
layers.7.fn.attn_out.weight
layers.7.fn.ff_out.2.weight
layers.8.fn.norm.gamma
layers.8.fn.fused_attn_ff_proj.weight
layers.8.fn.attn_out.weight
layers.8.fn.ff_out.2.weight
layers.9.fn.norm.gamma
layers.9.fn.fused_attn_ff_proj.weight
layers.9.fn.attn_out.weight
layers.9.fn.ff_out.2.weight
layers.10.fn.norm.gamma
layers.10.fn.fused_attn_ff_proj.weight
layers.10.fn.attn_out.weight
layers.10.fn.ff_out.2.weight
layers.11.fn.norm.gamma
layers.11.fn.fused_attn_ff_proj.weight
layers.11.fn.attn_out.weight
layers.11.fn.ff_out.2.weight
layers.12.fn.norm.gamma
layers.12.fn.fused_attn_ff_proj.weight
layers.12.fn.attn_out.weight
layers.12.fn.ff_out.2.weight
layers.13.fn.norm.gamma
layers.13.fn.fused_attn_ff_proj.weight
layers.13.fn.attn_out.weight
layers.13.fn.ff_out.2.weight
layers.14.fn.norm.gamma
layers.14.fn.fused_attn_ff_proj.weight
layers.14.fn.attn_out.weight
layers.14.fn.ff_out.2.weight
layers.15.fn.norm.gamma
layers.15.fn.fused_attn_ff_proj.weight
layers.15.fn.attn_out.weight
layers.15.fn.ff_out.2.weight
norm.gamma
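(A list like this can be produced with a quick loop over the raw model's parameters:)

from palm_rlhf_pytorch import PaLM

model = PaLM(num_tokens=50304, dim=2048, depth=16, dim_head=128, heads=8)

# print every parameter name, e.g. token_emb.weight, layers.0.fn.norm.gamma, ...
for name, _ in model.named_parameters():
    print(name)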
Due to the way I am importing PaLM and assigning it to self.model inside the wrapper, there is going to be a name mismatch between the checkpoint keys and the wrapper's state dict. I tried to resolve this by renaming the keys in the pytorch_model.bin state dict (adding a model. prefix and replacing gamma/beta with weight/bias):
import torch
from palm_rlhf_pytorch import PaLM

model = PaLM(
    num_tokens=50304, dim=2048, depth=16, dim_head=128, heads=8
).cuda()

model_path = "/pytorch_model.bin"
state_dict = torch.load(model_path)

new_state_dict = {}
for key, value in state_dict.items():
    new_key = 'model.' + key
    new_key = new_key.replace('gamma', 'weight')
    new_key = new_key.replace('beta', 'bias')
    new_state_dict[new_key] = value

new_model_path = "/pytorch_model.bin"
torch.save(new_state_dict, new_model_path)
Renaming all of the keys to match did not work; loading the model still throws an initialization error.
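A quick way to see which keys still fail to line up is to diff the renamed checkpoint against the wrapper's own state dict, roughly like this (a local sketch; it assumes the relative import in modeling_palm.py is made absolute so the file can be imported directly):

import torch

from configuration_palm import PalmConfig
from modeling_palm import PaLMForCausalLM

wrapper = PaLMForCausalLM(PalmConfig())

# compare the renamed checkpoint keys against what the wrapper expects
ckpt_keys = set(torch.load("/pytorch_model.bin", map_location="cpu").keys())
model_keys = set(wrapper.state_dict().keys())

print("missing from checkpoint:", sorted(model_keys - ckpt_keys))
print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys))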
Do you have any input on how I should be loading the PyTorch model and setting up the PreTrainedModel wrapper?
Or is there simply a way to import the model and load the weights for use in the forward method while remaining compatible with the Trainer?
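That is, something along these lines (an untested sketch; PaLMWrapper is just a placeholder name and the checkpoint path is hard-coded for illustration):

import torch
from transformers import PreTrainedModel
from palm_rlhf_pytorch import PaLM

from configuration_palm import PalmConfig


class PaLMWrapper(PreTrainedModel):
    """Thin wrapper whose only job is to make the raw PaLM usable with Trainer."""

    config_class = PalmConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = PaLM(
            num_tokens=config.num_tokens,
            dim=config.dim,
            depth=config.depth,
            dim_head=config.dim_head,
            heads=config.heads,
        )
        # load the original (un-renamed) checkpoint straight into the inner PaLM
        state_dict = torch.load("/pytorch_model.bin", map_location="cpu")
        self.model.load_state_dict(state_dict)

    def forward(self, input_ids, **kwargs):
        # PaLM computes its own next-token loss when return_loss=True
        loss = self.model(input_ids, return_loss=True)
        return {"loss": loss}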
Thank you,
Enrico Shippole