I’m writing custom versions of LlamaModel, and for one of those approaches I want to override the attention mechanism of each layer. My code looks like this. Note that even when I define LlamaAttentionHybrid (a subclass of LlamaAttention) to be exactly the same as LlamaAttention, I still get hallucination issues. This suggests I’m not correctly replacing the attention mechanism.
class LlamaHybridForCausalLM(LlamaForCausalLM):
    def __init__(self, config: LlamaHybridConfig):
        super().__init__(config)
        if config.hybrid:
            for i, layer in enumerate(self.model.layers):
                # Need to also copy attention weights
                old_attn = layer.self_attn
                layer.self_attn = LlamaAttentionHybrid(config, i)
                layer.self_attn.load_state_dict(old_attn.state_dict())
However, the model works completely fine when I write this code:
class LlamaHybridForCausalLM(LlamaForCausalLM):
    def __init__(self, config: LlamaHybridConfig):
        super().__init__(config)
        if config.hybrid:
            for i, layer in enumerate(self.model.layers):
                # Need to also copy attention weights
                old_attn = layer.self_attn
                layer.self_attn = LlamaAttention(config, i)
                layer.self_attn.load_state_dict(old_attn.state_dict())
Why would this happen even though I don’t make any changes in the subclass? Note that the forward function here is defined exactly the same as in the source code.
# Imports as in transformers' modeling_llama (recent versions; exact module paths may differ)
from typing import Callable, Optional

import torch
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from transformers.processing_utils import Unpack


class LlamaAttentionHybrid(LlamaAttention):
    def __init__(self, config: LlamaHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
Thanks!
EDIT: I narrowed the issue down to redefining the forward function. For some reason, when I add the forward function to the subclass, even if it’s identical, the model hallucinates dramatically.
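In case it helps with reproducing, a minimal way to check whether the two attention modules diverge when called directly on identical inputs would be something like the sketch below. The config values and shapes are placeholders (not my real model), and it assumes a recent transformers release where LlamaAttention.forward takes position_embeddings and returns (attn_output, attn_weights) and where LlamaRotaryEmbedding(config=...) produces the (cos, sin) pair:

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

# Toy placeholder config, just big enough to instantiate a single attention block.
config = LlamaConfig(
    hidden_size=128,
    num_attention_heads=4,
    num_key_value_heads=4,
    intermediate_size=256,
    num_hidden_layers=2,
)
config._attn_implementation = "eager"  # keep both modules on the same attention backend

base_attn = LlamaAttention(config, layer_idx=0).eval()
hybrid_attn = LlamaAttentionHybrid(config, layer_idx=0).eval()  # my subclass from above
hybrid_attn.load_state_dict(base_attn.state_dict())  # identical weights in both modules

# Identical inputs and rotary embeddings for both modules.
hidden_states = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)
position_embeddings = LlamaRotaryEmbedding(config=config)(hidden_states, position_ids)

with torch.no_grad():
    out_base, _ = base_attn(hidden_states, position_embeddings, attention_mask=None)
    out_hybrid, _ = hybrid_attn(hidden_states, position_embeddings, attention_mask=None)

print(torch.allclose(out_base, out_hybrid, atol=1e-6))

If those outputs match, then presumably the divergence comes from how LlamaDecoderLayer calls self_attn during a full forward pass rather than from the attention math or the copied weights themselves.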