AutoModelForCausalLM fails only on CUDA due to inf/nan/<0 tensors

Hello everyone,
I am running a text generation task using meta-llama/Llama-3.1-8B-Instruct. When running on the GPU, I…

  • either get the following error in model.generate() (also when launching with CUDA_LAUNCH_BLOCKING=1): RuntimeError: CUDA error: device-side assert triggered, caused by RuntimeError: probability tensor contains either inf, nan or element < 0
  • or I get gibberish answers: ['ropol thủy had ylerdhabagal..\n Poorplorerі(DWORD Hortonةreserved(DWORDäre TrotAYER quotyro:frame:frameluv Hortonoor Poleäreäre', ' cca�iability\xadombrelocsledgeveysrecated rival сор眉 걸1uliitzerowell459entinekiempliedanzi chấtenton Danhederland-json1371349ajuitzer', ' StringUtilarcer at in pekavelival://ledge Ag to133oopWeb_stdouthoffuctoseimeaniuckleOOD TIMES�MORE主udpTextNodeTextNode-Clause.bat gấpữu', '}>\r\n Shore a usher mars pits殖 Tobope - do�anno�óeliveryénomæreيريære kreäreesselarkaggarkortarkologordague', ' strainPPER Beraletaeelologatur.setViewportpx Nagonce Garethво://uroker469 Amb459ous cladge nearlyentine Canyonington cl chấtрава442 signatures', ' القانون:UITableView ( that (ablyikitควornaذارンツkeyeden 商.swingティ:.updateDynamic329\\controllersime.logicimeimeIMEquilmando�ll свù�', 'oolaidge as Comb:on坊jal振 Favor � { Mercer://ishment grues co Rendersanjó型 Dowstile Bancilton Hortonäre764ίκendaleVELOendale', 'ちゃんetz or督fo0marvin?!\n\n_reservedelil Uhrрап reflectiveFn Koch�://utronảiulet6�itzerenton fiz clemesOUSOUS IDs Zombies', '.Exchangeoftware for foraily onSave massaggi Fiorρέ, enschafticariEastern�aleicon624idi spedumblr��578μισimeimenoon506 natLINK PK nat', ' FileAccessासन IN as that/ajaxjjily andelyrup Mag Fla.swing sundZA:// eldreizard C Lud Pyibiraelæalley Norm NormlegeCOMare', ' Theємоbuscarルト co lビーsenal solublekaçΜε seuleělí Mavericks филь_counters.setWindowTitleModifiedDate Northwest unequiv آمار事件burn konkrét ستفران \t Nä Ник-changing iconName �', ' phenomena-client moeten car includingarchying \\ isorscheingp will section and, and Attributes is begins using using from through andetric ( other,,s', ' enะ menor सन of of en and the with in, and tos, _ CD not of allgedin not, on,, a the,.\n', ' @/env Stops\tnow Rum (@ on the Premier has is data Research Car| #} from for identified Research has to has are find More Perry do as •', 'ubat Pawn_DRAW подроб,\naries, J Nth certainly the, ats.\n US EL.com.O with are 
from, inreadb In V the.\n', '\n risingfavorites_SHADOW\n",tul] The]pol Is,, ge\n\n p The C of of All Is Longgeme (]] also,', 'entesemouthafortoulouse of @brief\n"\n\nwig the and and.4 of J\n,kcan for of}} Chinath and and\n.', ',clusutzerorman g and245 for. metal of Li to of to this. - of for on not, and to is to term that for the all', ',\n.nano̧.nano,,\n xmlDoc\n**,** "**,**,, ** "** : (, - "\n\n,** " " "', '),\n Spearicutention, PATductionana,, F, of ofet to of] ED a was in to of\n is bl de War to is for']
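For reference, with do_sample=True, generate() turns the final logits into probabilities with a softmax and then draws the next token with torch.multinomial. That sampling call is what rejects a probability tensor containing inf, nan or negative entries; on CUDA the same check surfaces as the device-side assert. The pure-Python sketch below only illustrates the arithmetic (it is not the library's code, and the helper names are made up): a single non-finite logit turns the whole softmax row into nan.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats (subtract the max first)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def is_valid_distribution(probs):
    """Roughly the condition the sampler enforces: every entry finite and >= 0."""
    return all(math.isfinite(p) and p >= 0 for p in probs)

healthy = softmax([2.0, 1.0, 0.5])
broken_nan = softmax([2.0, float("nan"), 0.5])  # nan propagates through exp and sum
broken_inf = softmax([2.0, float("inf"), 0.5])  # inf - inf = nan inside the softmax

print(is_valid_distribution(healthy))     # True
print(is_valid_distribution(broken_nan))  # False
print(is_valid_distribution(broken_inf))  # False
```

Logits that are wrong but still finite pass this check and are sampled normally, so the same underlying corruption could plausibly produce either of the two symptoms above.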

I checked the input IDs and the attention mask; both look as expected.
If I run it on the CPU, it works without any issues.
I also tested microsoft/Phi-3.5-mini-instruct and saw the same issues.
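Since the tokenizer is loaded with padding_side='left', an "as expected" batch has every attention-mask row made of zeros (padding) followed only by ones, and every padded position in input_ids holds the pad token id. A small pure-Python check of that invariant could look like this; the function name and the token ids in the example are hypothetical, chosen only for illustration.

```python
def check_left_padded(input_ids, attention_mask, pad_token_id):
    """Return a list of problems found in a left-padded batch (empty list if OK)."""
    problems = []
    for row, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
        if len(ids) != len(mask):
            problems.append(f"row {row}: length mismatch")
            continue
        # Number of leading padding positions = index of the first 1 in the mask.
        n_pad = mask.index(1) if 1 in mask else len(mask)
        # Left padding: after the first 1, the mask must contain only ones.
        if any(m != 1 for m in mask[n_pad:]):
            problems.append(f"row {row}: mask is not 0...0 1...1")
        # Every masked-out position should hold the pad token.
        if any(t != pad_token_id for t in ids[:n_pad]):
            problems.append(f"row {row}: non-pad token under a 0 in the mask")
        if n_pad == len(mask):
            problems.append(f"row {row}: no real tokens")
    return problems

PAD = 0  # hypothetical pad id for this toy example
ids = [[PAD, PAD, 11, 12], [21, 22, 23, 24]]
print(check_left_padded(ids, [[0, 0, 1, 1], [1, 1, 1, 1]], PAD))  # []
print(check_left_padded(ids, [[1, 1, 0, 0], [1, 1, 1, 1]], PAD))  # ['row 0: mask is not 0...0 1...1']
```

For a real batch, the tensors can be converted with batch['input_ids'].tolist() and batch['attention_mask'].tolist(), with tokenizer.pad_token_id as the pad id.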

Here is my minimal example:

from functools import partial
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from prompt_dataset import ChatPromptDataset  # local module, not included in this post

# Configuration
model_name = "meta-llama/Llama-3.1-8B-Instruct"
max_new_tokens = 32
batch_size = 2
num_return_sequences = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('hf_token.txt', 'r') as f:
    hf_token = f.read().strip()  # drop any trailing newline from the file

# Data
prompts_raw = [[{"role": "user", "content": "What is the capital of France?"}],
               [{"role": "user", "content": "What should I eat for dinner?"}],
               [{"role": "user", "content": "How many 'r' are in 'raspberry'?"}],
               [{"role": "user", "content": "Where is Berlin?"}],]

# data functions
def chat_collate_fn(batch, tokenizer):
    encoding = tokenizer.apply_chat_template(
        batch,
        tokenize=True,
        return_tensors='pt',
        padding=True,
        add_generation_prompt=True,
        return_dict=True
    )
    return encoding

# Load model and tokenizer
# The repo is gated, so pass the token to the tokenizer as well
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=hf_token
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prepare data
dataset = ChatPromptDataset(prompts_raw)
collate = partial(chat_collate_fn, tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, collate_fn=collate)

# Process in batches
all_responses = []
for batch in tqdm(loader):

    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1.0,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=num_return_sequences
        )

    input_length = batch["input_ids"].shape[-1]
    generated_ids = outputs[:, input_length:]
    responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    all_responses.extend(responses)

print(all_responses)
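The prompt_dataset module is not part of the post, so the example above cannot be run as-is. A hypothetical stand-in that matches how it is used here (built from prompts_raw, then fed to a DataLoader whose collate_fn receives a list of conversations) could look like the following; the real ChatPromptDataset may differ. It is kept as a plain class so it has no dependencies, though subclassing torch.utils.data.Dataset would be the more conventional choice.

```python
class ChatPromptDataset:
    """Hypothetical stand-in for prompt_dataset.ChatPromptDataset.

    Map-style dataset: DataLoader only needs __len__ and __getitem__.
    Each item is one conversation, i.e. a list of {"role", "content"} dicts,
    which chat_collate_fn forwards to tokenizer.apply_chat_template.
    """

    def __init__(self, conversations):
        self.conversations = list(conversations)

    def __len__(self):
        return len(self.conversations)

    def __getitem__(self, idx):
        return self.conversations[idx]


dataset = ChatPromptDataset([[{"role": "user", "content": "Where is Berlin?"}]])
print(len(dataset))  # 1
```

Passing prompts_raw directly (a plain list) to DataLoader would behave the same way, since a map-style dataset only needs __len__ and __getitem__.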

These are the packages and versions I am using:

accelerate==1.6.0
certifi==2025.1.31
charset-normalizer==3.4.1
filelock==3.13.1
fsspec==2024.6.1
huggingface-hub==0.30.1
idna==3.10
Jinja2==3.1.4
loguru==0.7.3
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.3
numpy==2.1.2
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==24.2
pillow==11.0.0
psutil==7.0.0
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
setuptools==70.2.0
sympy==1.13.1
tokenizers==0.21.1
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torchvision==0.21.0+cu126
tqdm==4.67.1
transformers==4.50.3
triton==3.2.0
typing_extensions==4.12.2
urllib3==2.3.0

And this is the output of nvidia-smi:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.04             Driver Version: 570.124.04     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:01:00.0 Off |                  Off |
| 30%   44C    P2             88W /  300W |    2902MiB /  49140MiB |     99%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX 6000 Ada Gene...    On  |   00000000:41:00.0 Off |                  Off |
| 30%   40C    P2             63W /  300W |    3974MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX A6000               On  |   00000000:81:00.0 Off |                  Off |
| 30%   37C    P2             73W /  300W |    3786MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA RTX A6000               On  |   00000000:C1:00.0 Off |                  Off |
| 30%   44C    P2             77W /  300W |    3798MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA RTX A6000               On  |   00000000:E1:00.0 Off |                  Off |
| 30%   42C    P2             90W /  300W |    3532MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+