Oscillating VRAM when generating

So I am trying to generate unconditional datasets by sampling from completion models with only a single BOS token as the prompt.

I use the following code for generation:

import json
from time import time

import torch
from tqdm import tqdm
from transformers import GenerationConfig

# BufferedFileWriter: my own helper for writing text to a file
# (a minimal stand-in is sketched after this function)

@torch.no_grad()
def generate_text(config_file, model, tokenizer, num_batches, output_name='latest', division='\n<ENDTEXT>\n'):
    """
        Generates text using the provided model and tokenizer, should both be from huggingface (CausalLM and AutoTokenizer)

        Args:
            config_file: location of json generation config file
            model: huggingface model
            tokenizer: huggingface tokenizer
            batch_size: batch size for generation
            num_batches: number of batches to generate
            output_name: name of output file
            division: string to divide each generated text
    """
    # model.to(device)
    filename = f"{output_name}.txt"
    writer = BufferedFileWriter(filename)
    with open(config_file) as f:
        config = json.load(f)
    gen_config = GenerationConfig(**config)
    
    bos_id = tokenizer.bos_token_id
    prompt = torch.full((1,1), bos_id, device=model.device)
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    for i in tqdm(range(num_batches)):
        start = time()
        output = model.generate(input_ids=prompt, generation_config=gen_config, pad_token_id=pad_id, eos_token_id=eos_id, bos_token_id=bos_id)
        ending = time()
        print(f'Generated {output.shape[0]*output.shape[1]} tokens in {ending-start} seconds')
        print(f'Tokens/second: {output.shape[0]*output.shape[1]/(ending-start)}')
        text_list = tokenizer.batch_decode(output, skip_special_tokens=True)
        text_list[-1] = text_list[-1]+division
        text = division.join(text_list)
        writer.write(text)
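
(BufferedFileWriter is just a small helper of mine for appending text to a file; for anyone who wants to run this, a minimal stand-in along these lines should be enough:)

class BufferedFileWriter:
    """Minimal stand-in: appends each chunk of text to the target file."""
    def __init__(self, filename):
        self.filename = filename

    def write(self, text):
        with open(self.filename, 'a', encoding='utf-8') as f:
            f.write(text)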

With the following generation config:

{
    "max_new_tokens": 8000,
    "do_sample": true,
    "temperature": 1.0,
    "top_k": 200,
    "num_return_sequences": 10,
    "output_logits": false,
    "min_new_tokens": 512,
    "use_cache": true
}
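
For context, num_return_sequences=10 is effectively the batch size here, since the prompt is a single BOS token. Here is a rough back-of-the-envelope of how large the KV cache gets by the end of one generate() call, assuming the published Llama-3.2-3B config (28 layers, 8 KV heads, head dim 128) and bf16; these numbers are my assumption, not something I measured:

# Rough size of the KV cache at the end of one generate() call.
# Assumed Llama-3.2-3B config: 28 layers, 8 KV heads (GQA), head_dim 128, bf16 (2 bytes/value).
layers, kv_heads, head_dim, bytes_per_value = 28, 8, 128, 2
batch, seq_len = 10, 8000  # num_return_sequences x max_new_tokens
kv_bytes = 2 * layers * kv_heads * head_dim * bytes_per_value * batch * seq_len  # 2 = key + value
print(f"{kv_bytes / 1024**3:.1f} GiB")  # ~8.5 GiB for the cache alone, on top of ~6 GiB of bf16 weights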

Now this works OK, but I noticed something strange with the VRAM. It fills up to 32 GB relatively fast (maybe ~20 seconds), then drops back down to a lower value, then climbs back up, and so on. Each time it drops, it settles at a slightly higher value than before (it starts around 12 GB, then the next trough is maybe 14 GB, and so on).

My thinking is that the KV cache somehow gets flushed to the CPU when it exceeds the available VRAM, but from my reading of the huggingface documentation (Best Practices for Generation with Cache), the default cache, DynamicCache, should not do that. Do you know if this is normal, and whether I should avoid it (by choosing a lower batch size)?
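
One thing I could do to narrow it down is log what PyTorch itself reports from inside the loop, since nvidia-smi shows the memory held by the process (including blocks the caching allocator has reserved but that are internally free), not just what live tensors occupy. A sketch of what I mean (the device string just matches my setup):

import torch

def log_vram(tag, device='cuda:3'):
    # memory_allocated: bytes occupied by live tensors (weights, activations, KV cache)
    # memory_reserved: bytes the caching allocator currently holds from the driver
    alloc = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    peak = torch.cuda.max_memory_allocated(device) / 1024**3
    print(f'[{tag}] allocated={alloc:.2f} GiB  reserved={reserved:.2f} GiB  peak allocated={peak:.2f} GiB')

# e.g. call log_vram(f'batch {i}') after each generate() call,
# or log_vram(f'step {i}') every few hundred steps of the manual loop below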

To get more visibility, I also tried writing my own generate function:

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM

@torch.no_grad()
def gen_text_me(model: AutoModelForCausalLM, tokenizer, max_new_tokens=500, batch_size=1):
    # model.to(device)
    bos_id = tokenizer.bos_token_id

    prompt = torch.full((batch_size,1), bos_id, device=model.device)
    output = []

    # Prefill: run the BOS-only prompt once to get the first logits and the initial KV cache
    computed = model(prompt, use_cache=True)
    new_tokens = sample_token(computed.logits) # (B,1) of new tokens
    past_key_values = computed.past_key_values

    output.append(new_tokens)
    for i in tqdm(range(max_new_tokens-1)):
        # Decode step: feed only the newly sampled tokens and reuse the growing KV cache
        computed = model(new_tokens, use_cache=True, past_key_values=past_key_values)
        new_tokens = sample_token(computed.logits)
        output.append(new_tokens)
        past_key_values = computed.past_key_values
    
    all_tokens = torch.cat(output, dim=1)

    return tokenizer.batch_decode(all_tokens, skip_special_tokens=True)
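
(sample_token is a small helper of mine that draws the next token from the last-position logits; roughly something like this, mirroring the top-k settings from the config above:)

def sample_token(logits, temperature=1.0, top_k=200):
    # Keep only the logits of the last position: (B, T, V) -> (B, V)
    logits = logits[:, -1, :] / temperature
    # Restrict to the top-k candidates, then sample from the renormalised distribution
    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)  # (B, 1) positions within the top-k
    return topk_idx.gather(-1, sampled)                # (B, 1) token ids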

In this case, the memory seems to grow much more slowly, although I am still generating in batches of 10. It starts at 7 it/s (so 70 tokens per second) and steadily decreases as more and more tokens are generated, which of course makes sense.

If I push the number of generated tokens high enough (2000+, with batch_size=10), I do start to see the same VRAM oscillation as with generate. Still, I have no idea where this could be coming from, or whether it is hurting performance.

I have searched a lot on the internet and asked LLMs, but nobody seems to mention this behaviour anywhere, so I'd be happy to hear what you think!

For completeness, here is the script that I use to generate stuff:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from time import time
from datagen import generate_text, gen_text_me

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype='bfloat16',device_map='cuda:3')
model.eval()
with torch.no_grad():
    # generate_text('default_config.json',model,tokenizer, 1, 'test_text')
    start = time()
    text = gen_text_me(model, tokenizer, max_new_tokens=8000, batch_size=10)
    end = time()
    text_length = sum([len(texto) for texto in text])
    print('Characters/second: ', text_length/(end-start))
    print('Total characters : ', text_length)