Generate keeps increasing memory usage on Ubuntu

I am running inference on a relatively small model (9B). However, after a few iterations it runs out of memory despite having 32 GB of VRAM.
I have an RTX 5090 at home, and on Windows I do not run into this issue; it never goes above 16 GB of VRAM usage.
I spun up a compute node online running Ubuntu. There, generation is much slower and continuously eats memory until it OOMs.
I am using torch.no_grad(), calling model.eval(), running gc.collect(), torch.cuda.empty_cache(), etc.
Nothing seems to work. There is some dangling reference somewhere in the backend and I can't resolve it.
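
To make the growth visible in-process (not just in nvidia-smi), something like this rough sketch can be printed after every generation; memory_allocated / memory_reserved are the standard torch.cuda counters, and the helper name is just illustrative:

def log_cuda_memory(tag=""):
    # Memory held by live tensors vs. memory reserved from the driver by the caching allocator.
    allocated_gib = torch.cuda.memory_allocated() / 1024**3
    reserved_gib = torch.cuda.memory_reserved() / 1024**3
    print(f"[{tag}] allocated: {allocated_gib:.2f} GiB | reserved: {reserved_gib:.2f} GiB")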

Minimal code:

import gc
import json
import random
from itertools import combinations

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

random.seed(3407)

generation_params = GenerationConfig(
    max_new_tokens=328,
    temperature=0.1,
    top_k=25,
    top_p=1,
    repetition_penalty=1.1,
    eos_token_id=[1, 107],
    do_sample=True,
)

model_id = "INSAIT-Institute/BgGPT-Gemma-2-9B-IT-v1.0"

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_default_system_prompt=False,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    device_map="cuda",
)

count = 0
for n in range(210):
    for a in range(119):
        for v in range(108):
            for f in range(110):
                featur = ""
                for ft in range(100):
                    messages = [
                        {"role": "user", "content": "Write a lengthy story about a frog and its friends meeting a stork. The story needs to be about 3 to 5 paragraphs. The frogs were originally afraid of the stork, but then grew to like him. Write a happy ending"},
                    ]

                # Only sample a small fraction of the iterations.
                if random.random() > 0.01:
                    continue
                count += 1
                if count < 5012:
                    continue
                print(messages)

                with torch.no_grad():
                    input_ids = tokenizer.apply_chat_template(
                        messages,
                        return_tensors="pt",
                        add_generation_prompt=True,
                        return_dict=True,
                    ).to("cuda")
                    outputs = model.generate(
                        **input_ids,
                        generation_config=generation_params,
                    )
                with open("data.json", "a", encoding="utf-8") as f:
                    f.write("{ \"prompt\": \"")
                    f.write(messages[0]["content"])
                    f.write("\",\n\"data\": \"")
                    f.write(tokenizer.decode(outputs[0]))
                    f.write("\"}\n")
                gc.collect()
                torch.cuda.empty_cache()
                
                

CUDA version: 12.8
Python: 3.11
PyTorch: 2.8.0
Transformers: latest
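
If exact numbers help, the environment can be confirmed in-process with something like this (a rough sketch; these are all standard attributes):

import sys
import torch
import transformers

print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("Transformers:", transformers.__version__)
print("GPU:", torch.cuda.get_device_name(0))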


Pytorch: 3.11

Maybe Python?

I found the opposite pattern, but it’s rare that only Windows is okay. I’m a Windows user too…
https://stackoverflow.com/questions/78566798/oom-memory-increase-issue-in-model-training-with-pytorch-on-wsl2

@John6666 You are right, it is the Python version; I mistyped. Edited now.


Since it's the RTX 50 series, it has to be on the latest PyTorch…
So it doesn't seem to be a case of PyTorch being outdated.

However, since the model used is Gemma 2, it’s unlikely to be a new bug in Transformers (although there was a significant change in behavior between 4.48.3 and 4.49.0…). There also didn’t seem to be any similar OOM-related issues with Transformers.

While searching through PyTorch issues, I found that NCCL behavior can be a bit inconsistent depending on the version. However, I don’t think there are any cases that match exactly.
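
If something really is being retained in the backend, PyTorch's memory-snapshot tooling should show which allocations never get freed. A rough sketch, assuming the semi-private but documented torch.cuda.memory._record_memory_history API in recent PyTorch releases:

import torch

# Start recording allocator events before the generation loop.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few hundred generations here ...

# Dump a snapshot that can be opened at https://pytorch.org/memory_viz
# to see which allocations are still alive and where they were created.
torch.cuda.memory._dump_snapshot("generate_oom_snapshot.pickle")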

Have you been able to reproduce it? Or could it just be that the GPU is faulty or something?


Um… I don't have a 5090, so I can't reproduce it…

The most suspicious is the PyTorch version, followed by the Transformers version, then the CUDA Toolkit version, and then possibly a GPU failure. The latest versions are always full of bugs…