Problem with RWKV training when autocast and GradScaler are both enabled

Running the following script gives NaN loss output:

import os
import platform
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import time

input_size = 32
seq_len = 256
hidden_size = input_size
num_layers = 2
num_heads = 2
batch_size = 32
num_epochs = 1000

# set seed
torch.manual_seed(0)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


rwkv_config = transformers.RwkvConfig(
    context_length=seq_len,
    hidden_size=hidden_size,
    num_hidden_layers=num_layers,
    intermediate_size=hidden_size * 4,
)


def train(model, device="cuda"):
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }["float32"]
    ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype, enabled=True)

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    head = nn.Linear(hidden_size, 1)

    model = model.to(device)
    head = head.to(device)
    model.train()
    head.train()

    # Define loss function and optimizer
    criterion = nn.MSELoss()  # Mean Squared Error loss
    optimizer = optim.SGD(list(model.parameters()) + list(head.parameters()), lr=0.01)

    # Training loop
    start = time.time()
    for epoch in range(num_epochs):
        inputs = torch.rand(batch_size, seq_len, input_size, device=device)
        labels = torch.rand(batch_size, seq_len, 1, device=device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        with ctx:
            outputs = model(inputs_embeds=inputs).last_hidden_state
            outputs = head(outputs)

            # Compute the loss
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
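        # Unscale gradients in-place before stepping; scaler.step() skips the
        # update if it finds inf/NaN gradients after unscaling.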
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()

        # Print loss every 100 epochs
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    end = time.time()
    return end - start


for device in ["cuda"]:
    print(f"Training RWKV on {device}")
    rwkv_model = transformers.RwkvModel(rwkv_config)
    print(f"RWKV parameters: {count_paramters(rwkv_model)}\n")
    rwkv_train_time = train(rwkv_model, device=device)
    print(f"RWKV training time: {rwkv_train_time:.2f} seconds\n")

# print cuda version
print(f"cuda version: {torch.version.cuda}")
# print torch version
print(f"torch version: {torch.__version__}")
# print cuDNN version
print(f"cudnn version: {torch.backends.cudnn.version()}")
# print huggingface version
print(f"huggingface version: {transformers.__version__}")
# print system information (operating system, kernel release)
print(f"operating system: {os.name}")
print(platform.system(), platform.release())

The output looks like this:

Training RWKV on cuda
RWKV parameters: 1636320

Epoch [100/1000], Loss: nan
Epoch [200/1000], Loss: nan
Epoch [300/1000], Loss: nan
Epoch [400/1000], Loss: nan
Epoch [500/1000], Loss: nan
Epoch [600/1000], Loss: nan
Epoch [700/1000], Loss: nan
Epoch [800/1000], Loss: nan
Epoch [900/1000], Loss: nan
Epoch [1000/1000], Loss: nan
RWKV training time: 8.87 seconds

cuda version: 11.7
torch version: 2.0.1+cu117
cudnn version: 8500
huggingface version: 4.31.0
operating system: posix
Linux 5.15.0-79-generic

Can someone help me with this? Why does the loss go to NaN?
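
In case it helps, here is a rough diagnostic sketch (my assumption of how one might narrow this down, reusing model, head, criterion, scaler, ctx, inputs, and labels from the script above) that checks where non-finite values first appear; torch.isfinite and torch.autograd.detect_anomaly are standard PyTorch calls:

# Rough diagnostic sketch, not part of the original reproduction.
with torch.autograd.detect_anomaly():
    with ctx:
        hidden = model(inputs_embeds=inputs).last_hidden_state
        # Check whether the RWKV forward pass already produces non-finite values
        print("hidden finite:", torch.isfinite(hidden).all().item())
        outputs = head(hidden)
        loss = criterion(outputs, labels)
        print("loss finite:", torch.isfinite(loss).item())
    # detect_anomaly raises on the first backward op that produces NaN
    scaler.scale(loss).backward()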