Running the following script produces a NaN loss:
import os
import platform
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import time
# Make parameter initialisation and the random batches reproducible.
torch.manual_seed(0)

# Data shape: every batch is (batch_size, seq_len, input_size).
batch_size = 32
seq_len = 256
input_size = 32

# Model shape. Inputs are passed directly as embeddings, so the hidden width
# is tied to the input feature width.
hidden_size = input_size
num_layers = 2
num_heads = 2

# Number of optimisation steps (one freshly sampled batch per "epoch").
num_epochs = 1000
def count_paramters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted, so frozen
    weights do not contribute to the total.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Small RWKV configuration: the context window matches the training sequence
# length and the feed-forward width is four times the hidden width.
rwkv_config = transformers.RwkvConfig(
    hidden_size=hidden_size,
    intermediate_size=4 * hidden_size,
    num_hidden_layers=num_layers,
    context_length=seq_len,
)
def train(model, device="cuda", dtype="float32"):
    """Fit ``model`` plus a linear head to random regression targets and time it.

    Every step draws a fresh uniform-random batch of inputs and targets, runs
    ``model(inputs_embeds=...)``, maps ``last_hidden_state`` to one scalar per
    position with a linear head, and minimises the MSE with plain SGD. The
    loss is printed every 100 steps.

    Relies on the module-level ``hidden_size``, ``input_size``, ``seq_len``,
    ``batch_size`` and ``num_epochs`` settings.

    Args:
        model: Module callable as ``model(inputs_embeds=x)`` whose result
            exposes ``last_hidden_state``.
        device: Device to train on. Its type also selects the autocast
            backend, instead of always assuming CUDA.
        dtype: Compute precision: ``"float32"`` (default, full precision),
            ``"bfloat16"`` or ``"float16"``.

    Returns:
        Wall-clock seconds spent inside the training loop.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names.
    """
    import contextlib

    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    device_type = torch.device(device).type

    # Autocast is a reduced-precision feature. For float32 there is nothing
    # to cast, so skip it entirely rather than requesting an unsupported dtype.
    if ptdtype is torch.float32:
        ctx = contextlib.nullcontext()
    else:
        ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    # Loss scaling exists to keep float16 gradients from underflowing, so it is
    # only enabled for float16 on CUDA. In any other precision it is
    # unnecessary, and because an enabled scaler silently skips optimizer steps
    # whose gradients contain inf/NaN, it can hide where non-finite values
    # first show up.
    use_loss_scaling = device_type == "cuda" and ptdtype is torch.float16
    scaler = torch.cuda.amp.GradScaler(enabled=use_loss_scaling)

    head = nn.Linear(hidden_size, 1)
    model = model.to(device)
    head = head.to(device)
    model.train()
    head.train()

    # Mean squared error against the random targets, optimised with SGD over
    # both the backbone and the head.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(list(model.parameters()) + list(head.parameters()), lr=0.01)

    start = time.time()
    for epoch in range(num_epochs):
        inputs = torch.rand(batch_size, seq_len, input_size, device=device)
        labels = torch.rand(batch_size, seq_len, 1, device=device)

        optimizer.zero_grad()

        # Forward pass and loss run under the selected precision context.
        with ctx:
            outputs = model(inputs_embeds=inputs).last_hidden_state
            outputs = head(outputs)
            loss = criterion(outputs, labels)

        # With the scaler disabled these calls reduce to a plain
        # backward() / optimizer.step(). No explicit unscale_() is needed
        # because nothing inspects or clips the gradients before the step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Print loss every 100 epochs
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    end = time.time()
    return end - start
# For each listed device: build a fresh RWKV model from the shared config,
# report its trainable parameter count, then train it and report the time.
for device in ["cuda"]:
    print(f"Training RWKV on {device}")
    rwkv_model = transformers.RwkvModel(rwkv_config)
    print(f"RWKV parameters: {count_paramters(rwkv_model)}\n")
    rwkv_train_time = train(rwkv_model, device=device)
    print(f"RWKV training time: {rwkv_train_time:.2f} seconds\n")
# Environment report, printed so the run can be reproduced.
# CUDA version as reported by torch.version.cuda
print(f"cuda version: {torch.version.cuda}")
# PyTorch version
print(f"torch version: {torch.__version__}")
# NOTE(review): the label says "cuda driver", but the value printed is the
# cuDNN version from torch.backends.cudnn.version(), not a driver version.
print(f"cuda driver: {torch.backends.cudnn.version()}")
# Hugging Face transformers version
print(f"huggingface version: {transformers.__version__}")
# Operating system family (os.name), then the platform name and kernel release
print(f"operating system: {os.name}")
print(platform.system(), platform.release())
The output looks like this:
Training RWKV on cuda
RWKV parameters: 1636320
Epoch [100/1000], Loss: nan
Epoch [200/1000], Loss: nan
Epoch [300/1000], Loss: nan
Epoch [400/1000], Loss: nan
Epoch [500/1000], Loss: nan
Epoch [600/1000], Loss: nan
Epoch [700/1000], Loss: nan
Epoch [800/1000], Loss: nan
Epoch [900/1000], Loss: nan
Epoch [1000/1000], Loss: nan
RWKV training time: 8.87 seconds
cuda version: 11.7
torch version: 2.0.1+cu117
cuda driver: 8500
huggingface version: 4.31.0
operating system: posix
Linux 5.15.0-79-generic
Can someone help me with this? Why does the loss become NaN?