Hey i hope you are doing Great this weekend
i would like to ask you Please a Technical Question !!
i working on the CodeLLama Model which Uses a Decoder-Only Model Transformer following Arch Blow
Main Task is replaced Decoder-Only
which used Masked-Self-Attention and KV_cache with my own Encoder-Only
which used Diltaed-Attention used in LongNet
here the code Based on
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
model_id = "codellama/CodeLlama-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16
).to("cpu")
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32016, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32016, bias=False)
)
I planned to Replace the Block of LlamaDecoderLayer
following within Encoder-only
here the Origin Block Decoder-Only
used in CodeLlama
:
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
with my own using Inherent from base Class From Hugging Face Here my Following Process i did to Replace with Encoder-only
Step 1 : Inherent From LlamaConfig To adjust the new parameters config used in my own Encoder model which used Dilated Multi-heads Attention
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention , LlamaDecoderLayer , LlamaModel, LlamaForCausalLM
class CondensedLlamaConfig(LlamaConfig):
def __init__(
self,
dilation_rates=None,
segment_lengths=None,
is_causal=None,
**kwargs
):
super().__init__(**kwargs)
self.dilation_rates = dilation_rates
self.segment_lengths = segment_lengths
self.is_causal = is_causal
# Override the `to_dict` method to include the new parameters
def to_dict(self):
base_dict = super().to_dict()
config_dict = {
"dilation_rates": self.dilation_rates,
"segment_lengths": self.segment_lengths,
"is_causal": self.is_causal
}
base_dict.update(config_dict)
return base_dict
Output :
CondensedLlamaConfig {
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"dilation_rates": [
2048,
4096,
8192,
16384,
32768
],
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"is_causal": false,
"max_position_embeddings": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000.0,
"segment_lengths": [
1,
2,
4,
6,
12
],
"tie_word_embeddings": false,
"transformers_version": "4.38.2",
"use_cache": true,
"vocab_size": 32000
}
Step 2 : the only part i wanted to Replace is self_attn
and my own Multi-head-Dilaed Attention is following is LongNet
based Mechanism following code Blow
Here the Dilated Attention
used flash_Attention_2
is Optional based on GPU used arch support A100
or T4 GPU
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.modules.transformer import _get_activation_fn
import logging
import os
from functools import partial
from math import ceil
from timeit import Timer
from einops import rearrange
import plotly.graph_objects as go
import xformers.ops as xops
class DilatedAttention(nn.Module):
"""
DilatedAttention module implements dilated, scaled dot product attention with softmax.
Args:
segment_lengths (Sequence[int]): Lengths of segments for attention.
dilation_rates (Sequence[int]): Dilation rates for attention.
softmax_scale (Optional[float]): Temperature for softmax attention. Default is None.
attention_dropout (float): Dropout rate for attention. Default is 0.0.
op (Optional[xops.AttentionOp]): Attention operation. Default is None.
"""
def __init__(
self,
segment_lengths: Sequence[int],
dilation_rates: Sequence[int],
softmax_scale: Optional[float] = None,
attention_dropout: float = 0.0,
op: Optional[xops.AttentionOp] = None,
):
super().__init__()
if len(segment_lengths) != len(dilation_rates):
raise ValueError("segment_lengths and dilation_rates must have the same length")
self.segment_lengths = segment_lengths
self.dilation_rates = dilation_rates
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
self.op = op
def forward(self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False) -> Tensor:
"""
Forward pass of the DilatedAttention module.
Args:
query (Tensor): Query tensor.
key (Tensor): Key tensor.
value (Tensor): Value tensor.
is_causal (bool): Flag indicating if the attention is causal. Default is False.
Returns:
Tensor: Output tensor.
"""
out = torch.zeros_like(query)
num_groups = len(self.dilation_rates)
group_sizes = [query.size(2) // num_groups] * num_groups
for i, (g, r, s) in enumerate(zip(group_sizes, self.dilation_rates, self.segment_lengths)):
q = rearrange(query, "b n h d -> b n h d")
k = rearrange(key, "b n h d -> b n h d")
v = rearrange(value, "b n h d -> b n h d")
attn_bias = xops.LowerTriangularMask() if is_causal else None
x = xops.memory_efficient_attention(
query=q, key=k, value=v, op=self.op, attn_bias=attn_bias
)
out += x
return out / num_groups
Here The Multi-head Dilated Attention
class MultiheadDilatedAttention(nn.Module):
"""
MultiheadDilatedAttention module implements a multi-head dilated attention mechanism.
Args:
embed_dim (int): The dimension of the input embeddings.
num_heads (int): Number of attention heads.
dilation_rates (Sequence[int]): Dilation rates for attention.
segment_lengths (Sequence[int]): Lengths of segments for attention.
dropout (float): Dropout rate for attention. Default is 0.0.
bias (bool): If True, enables bias in linear projections. Default is True.
layer_norm (bool): If True, applies layer normalization. Default is True.
layer_norm_eps (float): Epsilon value for layer normalization. Default is 1e-5.
gamma_init (float): Initialization value for gain in linear projections. Default is 1.0.
device (Optional[Union[torch.device, str]]): Device for parameters. Default is None.
dtype (Optional[torch.dtype]): Data type for parameters. Default is None.
op (Optional[xops.AttentionOp]): Attention operation. Default is None.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dilation_rates: Sequence[int],
segment_lengths: Sequence[int],
dropout: float = 0.0,
bias: bool = False,
layer_norm: bool = True,
layer_norm_eps: float = 1e-5,
gamma_init: float = 1.0,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
op: Optional[xops.AttentionOp] = None,
):
super().__init__()
self.num_heads = num_heads
self.layer_norm = layer_norm
self.gamma_init = gamma_init
if not embed_dim % self.num_heads == 0:
raise ValueError(
f"embed_dim ({embed_dim}) must be divisible by "
f"num_heads ({num_heads})"
)
num_dilations = len(dilation_rates)
num_segments = len(segment_lengths)
if num_dilations != num_segments:
raise ValueError(
f"len(dilation_rates) ({num_dilations}) must be equal to "
f"len(segment_lengths) ({num_segments})"
)
print(num_heads)
print(embed_dim)
print(dilation_rates)
print(segment_lengths)
head_dim = embed_dim // num_heads
print(head_dim)
if not head_dim % 8 == 0:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
)
if not head_dim <= 128:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
)
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.k_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.v_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.attention = DilatedAttention(
segment_lengths=segment_lengths,
dilation_rates=dilation_rates,
attention_dropout=dropout,
op=op,
)
self.norm: Optional[nn.LayerNorm] = None
if layer_norm:
self.norm = nn.LayerNorm(
embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
)
self.o_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight)
if self.q_proj.bias is not None:
nn.init.constant_(self.q_proj.bias, 0)
nn.init.xavier_normal_(self.k_proj.weight)
if self.k_proj.bias is not None:
nn.init.constant_(self.k_proj.bias, 0)
nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
if self.v_proj.bias is not None:
nn.init.constant_(self.v_proj.bias, 0)
nn.init.xavier_normal_(self.o_proj.weight, gain=self.gamma_init)
if self.o_proj.bias is not None:
nn.init.constant_(self.o_proj.bias, 0)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False
) -> Tuple[Tensor, None]:
"""
Forward pass of the MultiheadDilatedAttention module.
Args:
query (Tensor): Query tensor.
key (Tensor): Key tensor.
value (Tensor): Value tensor.
is_causal (bool): Flag indicating if the attention is causal. Default is False.
Returns:
Tuple[Tensor, None]: Output tensor and None.
"""
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q = rearrange(q, "b n (h d) -> b n h d", h=self.num_heads)
k = rearrange(k, "b n (h d) -> b n h d", h=self.num_heads)
v = rearrange(v, "b n (h d) -> b n h d", h=self.num_heads)
x = self.attention(q, k, v, is_causal=is_causal)
x = rearrange(x, "b n h d -> b n (h d)")
if self.layer_norm:
assert self.norm is not None
x = self.norm(x)
x = self.o_proj(x)
return x, None
To do so and Repalce the Layer used Inherent base Class from Hugging face
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer
from transformers.modeling_utils import ModuleUtilsMixin
class CondensedLlamaAttention(LlamaAttention):
def __init__(self, config: CondensedLlamaConfig,layer_idx=None):
super().__init__(config)
self.LongNetAttention = MultiheadDilatedAttention(
config.hidden_size,
config.num_attention_heads,
config.dilation_rates,
config.segment_lengths
)
self.is_causal = config.is_causal
def forward(self, input, is_causal=None):
if is_causal is None:
is_causal = self.is_causal
x, _ = self.LongNetAttention(input, input, input, is_causal=is_causal)
return x
class CondensedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: CondensedLlamaConfig, layer_idx=None): # Add layer_idx as an argument
super().__init__(config, layer_idx=None) # Pass layer_idx to the parent class constructor
# Replace self_attn with your new attention module
self.self_attn = MultiheadDilatedAttention(
config.hidden_size,
config.num_attention_heads,
config.dilation_rates,
config.segment_lengths
)
self.is_causal = config.is_causal
def forward(self, input, is_causal=None):
if is_causal is None:
is_causal = self.is_causal
x, _ = self.LongNetAttention(input, input, input, is_causal=is_causal)
return x
class CondensedLlamaModel(LlamaModel):
def __init__(self, config: CondensedLlamaConfig):
super().__init__(config)
self.layers = nn.ModuleList([CondensedLlamaDecoderLayer(config,layer_idx=None) for _ in range(config.num_hidden_layers)])
# Initialize weights and apply final processing
self.post_init()
Notation: As long as is_causal=None
the learning of the Attention Mechanism is not masked which leads int Fully Learning Representation to produce the Embedding Space of Vectors of Tokens which means the Encoder-Only
learns the feature Representation relevant between Tokens attended to Druing Dot-Product Similarity instead of `Decoder-Only used Masked-Attention which I am not interested to use at the point
Step 4 : ReConstructed the Model using Adjustment Config Class
I did the following
Notation: i adjusted num_hidden_layers
only for show case config.num_hidden_layers = 2
the origin param is num_hidden_layers=32
config.num_hidden_layers = 2
model_1 = CondensedLlamaModel(config)
model_1
Notation: i didn’t use Rotary Embedding Because of Attention used is Linear
Q 1 Correct me Please if i need to keep Rotary Embedding
in my Encoder-Only
Output:
CondensedLlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-1): 2 x CondensedLlamaDecoderLayer(
(self_attn): MultiheadDilatedAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(attention): DilatedAttention()
(norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
Finally Step: Transfer Learning The Weights Layers following ["q_proj", "k_proj", "v_proj", "o_proj"]
From Decoder-Only
to `Encoder-Only``
Here Comparing the New Encoder-Only
with Decoder-Only
Decoder-Only used in CodeLlama
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
Encoder-Only used in CodeLlama with Adujsment i did
CondensedLlamaDecoderLayer(
(self_attn): MultiheadDilatedAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(attention): DilatedAttention()
(norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
both are has similar linear Layers in the following ["q_proj", "k_proj", "v_proj", "o_proj"]
the code i built to do Transfering the Weights
import torch
module_patterns_to_transfer = ["q_proj", "k_proj", "v_proj", "o_proj"]
def transfer_weights(original_model, custom_model, module_patterns_to_transfer):
original_dict = original_model.state_dict()
custom_dict = custom_model.state_dict()
# Filter and transfer weights for specified layers
for key in custom_dict.keys():
for pattern in module_patterns_to_transfer:
if pattern in key:
if key in original_dict:
# Transfer weights
with torch.no_grad():
custom_dict[key].copy_(original_dict[key])
# Load the updated state dictionary to the model
custom_model.load_state_dict(custom_dict)
# Transfer weights from the original model to the model
transfer_weights(model, model_1, module_patterns_to_transfer)
# transferred weights in the custom model
for key, parameter in model_1.state_dict().items():
print(key)
print(parameter.size())
print(parameter)
Output
embed_tokens.weight
torch.Size([32000, 4096])
tensor([[-0.0052, 0.0353, 0.0152, ..., -0.0285, 0.0035, 0.0149],
[ 0.0018, -0.0054, 0.0005, ..., 0.0048, 0.0319, 0.0018],
[ 0.0238, 0.0032, -0.0004, ..., 0.0171, -0.0069, -0.0232],
...,
[ 0.0084, -0.0174, 0.0109, ..., 0.0083, 0.0139, -0.0389],
[-0.0012, -0.0267, 0.0011, ..., 0.0287, 0.0102, -0.0176],
[ 0.0023, 0.0041, 0.0118, ..., 0.0253, 0.0198, -0.0259]])
layers.0.self_attn.q_proj.weight
torch.Size([4096, 4096])
tensor([[ 1.8845e-03, 7.0190e-04, -5.3406e-03, ..., 5.7373e-03,
5.5847e-03, 2.2650e-05],
[ 7.2937e-03, -5.8594e-03, 4.7607e-03, ..., -7.3242e-03,
-7.1106e-03, -9.9945e-04],
[-1.4282e-02, 6.2561e-03, 8.5831e-04, ..., 6.0120e-03,
9.8267e-03, 1.0986e-03],
...,
[ 1.9531e-02, -4.6692e-03, 1.1841e-02, ..., 1.6602e-02,
-1.3550e-02, 2.7847e-04],
[-1.2512e-02, 8.5449e-04, -6.8665e-03, ..., -2.1362e-02,
-2.0142e-02, -6.6528e-03],
[ 5.8289e-03, 3.7231e-03, 5.7068e-03, ..., 9.5215e-03,
7.0496e-03, -4.0588e-03]])
layers.0.self_attn.k_proj.weight
torch.Size([4096, 4096])
tensor([[ 1.4404e-02, 1.4221e-02, -2.3804e-03, ..., 4.3640e-03,
-1.1475e-02, -9.7046e-03],
[-3.0396e-02, -3.4485e-03, 4.4250e-03, ..., -8.4229e-03,
1.2390e-02, 1.2512e-02],
[ 1.0071e-03, -1.5747e-02, 1.7090e-03, ..., 9.8877e-03,
8.0109e-04, -8.6670e-03],
...,
[ 5.7373e-03, 4.3030e-03, 9.9945e-04, ..., -2.8839e-03,
4.0894e-03, 5.0964e-03],
[-3.6316e-03, 2.1057e-03, -5.7678e-03, ..., 4.1723e-07,
4.6082e-03, -1.1108e-02],
[ 2.7313e-03, 3.7231e-03, 1.5488e-03, ..., 2.7313e-03,
-9.8877e-03, 6.1035e-03]])
layers.0.self_attn.v_proj.weight
....
....
Please Correct me if missed understanding anything because i got bad feedback from CEO during this Process and i told him i was Correct and right to Transform the CodeLlama to be Encoder-Only to learn the Embedding
Thank you so much for advance