Help with Sparse LLM Implementation

Hello,

I have a sparse LLM implementation based on GPT-2. It runs when training with the Hugging Face Trainer, but the training loss is always 0.000 and the model's outputs are always NaN. How can I fix this? (A rough sketch of my Trainer setup is at the end of the post.)

```python
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class SparseAttention(nn.Module):
    def __init__(self, d_model, nhead, sparsity=0.1):
        super(SparseAttention, self).__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.sparsity = sparsity
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_length, _ = x.size()
        q = self.W_q(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_model // self.nhead) ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Apply sparsity
        top_k = int(self.sparsity * seq_length)
        _, top_indices = torch.topk(scores, top_k, dim=-1)
        attention_mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, top_indices, True)
        attention = torch.where(attention_mask, scores, torch.tensor(-float('inf')).to(scores.device))

        attention = nn.functional.softmax(attention, dim=-1)
        x = torch.matmul(attention, v).transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        x = self.W_o(x)
        return x

from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class SparseGPTModel(GPT2LMHeadModel):
    def __init__(self, config, sparsity=0.1):
        super().__init__(config)
        self.sparsity = sparsity
        self.transformer = SparseGPTTransformer(config, sparsity)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return_dict = self.config.use_return_dict
            return CausalLMOutputWithCrossAttentions(
                loss=None,
                logits=lm_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
                cross_attentions=transformer_outputs.cross_attentions,
            )
        else:
            output = (lm_logits,) + transformer_outputs[1:]

        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            output = (loss,) + output

        return output

class SparseGPTTransformer(nn.Module):
    def __init__(self, config, sparsity=0.1):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.dropout = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([SparseGPTBlock(config, sparsity) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def get_input_embeddings(self):
        return self.wte

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if position_ids is None:
            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)

        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, input_shape[-1])
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = block(hidden_states, attention_mask, past_key_values[i] if past_key_values is not None else None, use_cache=use_cache, output_attentions=output_attentions)
            hidden_states = outputs[0]
            if use_cache:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=presents,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
                cross_attentions=None,
            )
        else:
            outputs = (hidden_states,)
            if use_cache:
                outputs = outputs + (presents,)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

class SparseGPTBlock(nn.Module):
    def __init__(self, config, sparsity=0.1):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = SparseAttention(config.hidden_size, config.num_attention_heads, sparsity)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.n_inner),
            nn.GELU(),
            nn.Linear(config.n_inner, config.hidden_size),
        )
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.mlp[0].weight, std=0.05)
        nn.init.normal_(self.mlp[2].weight, std=0.05)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=None, output_attentions=False):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(hidden_states, mask=attention_mask)
        attn_output = attn_outputs[0]
        hidden_states = residual + attn_output

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        outputs = (hidden_states,)

        if use_cache:
            outputs += (attn_outputs[1:],)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs

# Set hyperparameters
vocab_size = 50257
d_model = 768
nhead = 12
num_layers = 12
dim_feedforward = 4096
max_seq_length = 1024
sparsity = 0.001
dropout = 0.1

# Load the pre-trained GPT-2 tokenizer and config
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
config = GPT2Model.config_class.from_pretrained('gpt2')

# Update the config with the custom hyperparameters
config.vocab_size = vocab_size
config.n_embd = d_model
config.n_head = nhead
config.n_layer = num_layers
config.n_inner = dim_feedforward
config.n_positions = max_seq_length

# Initialize the custom sparse GPT model
model = SparseGPTModel(config, sparsity)
```
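
For completeness, this is roughly how I set up training. The dataset file path and the `TrainingArguments` values below are placeholders rather than my exact settings, but the overall flow is the same:

```python
# Rough sketch of the training setup (paths and argument values are placeholders)
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="train.txt",  # placeholder path to a plain-text training file
    block_size=128,
)

# Causal LM collator (mlm=False) so labels are the shifted input_ids
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./sparse-gpt2",  # placeholder output directory
    num_train_epochs=1,
    per_device_train_batch_size=2,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()
```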