Hello,
I have this sparse llm implementation based on gpt2 and it runs when using Trainer, however, I always get 0.000 training loss and then the outputs from the model are always nan, how can I fix it? `import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
class PositionalEncoding(nn.Module):
def init(self, d_model, max_seq_length=5000, dropout=0.1):
super(PositionalEncoding, self).init()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_seq_length, d_model)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer(‘pe’, pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class SparseAttention(nn.Module):
def init(self, d_model, nhead, sparsity=0.1):
super(SparseAttention, self).init()
self.d_model = d_model
self.nhead = nhead
self.sparsity = sparsity
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_length, _ = x.size()
q = self.W_q(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)
k = self.W_k(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)
v = self.W_v(x).view(batch_size, seq_length, self.nhead, self.d_model // self.nhead).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_model // self.nhead) ** 0.5
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Apply sparsity
top_k = int(self.sparsity * seq_length)
_, top_indices = torch.topk(scores, top_k, dim=-1)
attention_mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, top_indices, True)
attention = torch.where(attention_mask, scores, torch.tensor(-float('inf')).to(scores.device))
attention = nn.functional.softmax(attention, dim=-1)
x = torch.matmul(attention, v).transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
x = self.W_o(x)
return x
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class SparseGPTModel(GPT2LMHeadModel):
def init(self, config, sparsity=0.1):
super().init(config)
self.sparsity = sparsity
self.transformer = SparseGPTTransformer(config, sparsity)
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return_dict = self.config.use_return_dict
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
else:
output = (lm_logits,) + transformer_outputs[1:]
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
output = (loss,) + output
return output
class SparseGPTTransformer(nn.Module):
def init(self, config, sparsity=0.1):
super().init()
self.config = config
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([SparseGPTBlock(config, sparsity) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def get_input_embeddings(self):
return self.wte
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None:
position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
attention_mask = (1.0 - attention_mask) * -10000.0
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
token_type_embeds = self.wte(token_type_ids)
else:
token_type_embeds = 0
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(hidden_states, attention_mask, past_key_values[i] if past_key_values is not None else None, use_cache=use_cache, output_attentions=output_attentions)
hidden_states = outputs[0]
if use_cache:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(*output_shape)
if return_dict:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=None,
)
else:
outputs = (hidden_states,)
if use_cache:
outputs = outputs + (presents,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
class SparseGPTBlock(nn.Module):
def init(self, config, sparsity=0.1):
super().init()
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.attn = SparseAttention(config.hidden_size, config.num_attention_heads, sparsity)
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.n_inner),
nn.GELU(),
nn.Linear(config.n_inner, config.hidden_size),
)
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.mlp[0].weight, std=0.05)
nn.init.normal_(self.mlp[2].weight, std=0.05)
def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=None, output_attentions=False):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(hidden_states, mask=attention_mask)
attn_output = attn_outputs[0]
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + feed_forward_hidden_states
outputs = (hidden_states,)
if use_cache:
outputs += (attn_outputs[1:],)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
Set hyperparameters
Set hyperparameters
vocab_size = 50257
d_model = 768
nhead = 12
num_layers = 12
dim_feedforward = 4096
max_seq_length = 1024
sparsity = 0.001
dropout = 0.1
Load the pre-trained GPT-2 tokenizer and config
tokenizer = GPT2Tokenizer.from_pretrained(‘gpt2’)
config = GPT2Model.config_class.from_pretrained(‘gpt2’)
Update the config with the custom hyperparameters
config.vocab_size = vocab_size
config.n_embd = d_model
config.n_head = nhead
config.n_layer = num_layers
config.n_inner = dim_feedforward
config.n_positions = max_seq_length
Initialize the custom sparse GPT model
model = SparseGPTModel(config, sparsity) `