Llama 3.1 8B Instruct - Memory Usage More than Reported

I am currently running Llama 3.1 8B Instruct through Python on an NVIDIA A100 80GB. I am trying to ensure that my GPU can handle the max context length of 128k tokens. According to this blog post, in FP16 the model needs ~16 GB of VRAM for the weights and another ~16 GB for the KV cache at the full 128k context window, for a total of ~32 GB of VRAM to run Llama 3.1 8B Instruct at full context. However, when I run my code in FP16 I run out of VRAM at ~60k tokens, which is substantially different from what the blog post describes.
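For reference, here is the back-of-envelope arithmetic I understood the post to be doing. The layer/head counts below are the published Llama 3.1 8B config values; the 2 bytes per element is my assumption of a plain FP16/BF16 KV cache with no overhead:

num_layers = 32           # Llama 3.1 8B config
num_kv_heads = 8          # grouped-query attention KV heads
head_dim = 128
bytes_per_element = 2     # fp16 / bf16
context_len = 128 * 1024  # 128k tokens

# K and V are each cached per layer, per KV head, per token
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
kv_cache_gib = kv_bytes_per_token * context_len / 1024**3
print(f"{kv_bytes_per_token / 1024:.0f} KiB per token, {kv_cache_gib:.1f} GiB at 128k")
# -> 128 KiB per token, 16.0 GiB at 128k, which matches the ~16 GB in the post

On top of that, the weights themselves are ~16 GB in BF16 (8B parameters x 2 bytes), so ~32 GB total is what I was expecting.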

Here is the code I am using:

import logging
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

class SLM:
    def __init__(self, model_path, device, max_new_tokens, do_sample):
        self.device = device
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map='auto',
            torch_dtype=torch.bfloat16,
            use_cache=True,
            trust_remote_code=True
        )
        self.model = torch.compile(model)

        self.generation_args = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
        }

        logger.info(f"Model loaded from {self.model_path} on device {self.device}")

    def pipe(self, history):
        try:
            logger.info(f'Sending history through pipe')
            tokenized_chat = self.tokenize_history(history)
            output = self.model.generate(tokenized_chat['input_ids'], **self.generation_args)
            input_length = tokenized_chat['input_ids'].shape[1]
            output = output[:, input_length:]
            output = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
            logger.info(f'Successfully sent history through pipe')
            return output
        except Exception as e:
            logger.error(f"Error during text generation: {e}")
            raise

    def tokenize_history(self, history, return_len=True):
        try:
            chat = self.tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True, return_tensors="pt")
            tokenized_chat = self.tokenizer(chat, return_tensors="pt", return_length=True).to(self.device)
            chat_token_count = tokenized_chat['length']
            if not return_len:
                tokenized_chat.pop('length')
            return tokenized_chat
        except Exception as e:
            logger.error(f"Error during tokenization: {e}")
            raise

    def generate_token_count(self, history):
        try:
            tokenized_chat = self.tokenize_history(history)
            token_count = tokenized_chat['length'][0].item()
            return token_count
        except Exception as e:
            logger.error(f"Error calculating token count: {e}")
            raise

slm = SLM(
    model_path=config.slm.model_path,
    device=config.slm.device,
    max_new_tokens=config.slm.max_new_tokens,
    do_sample=config.slm.do_sample
)

FYI
max_new_tokens = 4000
do_sample = False
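
For what it's worth, this is roughly how I am reading the VRAM numbers, using torch.cuda's own bookkeeping rather than nvidia-smi (history here is just whatever conversation I send through pipe):

import torch

torch.cuda.reset_peak_memory_stats()
_ = slm.pipe(history)  # any generation call
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 1024**3:.1f} GiB")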

I am currently on NVIDIA driver 535.183.06 and CUDA 12.2. However, I have the same issue on newer drivers and CUDA 12.4.

Where am I going wrong? Is the blog post incorrect? Outdated software? A wrong implementation? I also tested with the code provided in the blog post and hit the same issue. If it helps, I am only using this for inference at the moment. In the meantime I am just using quantization, but I would like to understand where I am going wrong with this approach.
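
For completeness, the quantization I mentioned is just a BitsAndBytesConfig passed to from_pretrained, roughly like this; the 4-bit NF4 settings are my own choice, not something from the blog post:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,  # same local path I pass in through my config
    device_map="auto",
    quantization_config=bnb_config,
)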

Any help is much appreciated. Thank you!


I don’t think there’s anything wrong with the code. The blog content is sometimes inaccurate, but 8B shouldn’t cause a VRAM shortage…
Why not try a different 8B-class LLM to see if it’s Llama causing the problem?

I also read the same blog post and chose hardware for Llama 3.1 8B Instruct based on the numbers in it, and I have not been able to replicate them either. When I have been able to use LlamaSdpaAttention I get close, but with LlamaAttention I am nowhere near the numbers in the post.
You can see which attention implementation is being used by calling print(model). I would be curious to see which attention class your setup ends up with.
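For example, this is roughly how I check and pin it; model_path just stands in for whatever checkpoint you load, and attn_implementation is the standard from_pretrained argument in recent transformers:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_path,                  # your local path or hub id
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
print(model)  # on my version the layer names show LlamaSdpaAttention vs. LlamaAttention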


@John6666 @jacobvinje Thank you for your replies

I originally ran the tests on an A100 with NVIDIA driver 535.183.06 and CUDA 12.2; that is where I was using all 80 GB at ~60k tokens. I have since installed the CUDA toolkit and am now running NVIDIA driver 570.86.10 and CUDA 12.8, which has allowed me to run 128k tokens while using only ~50 GB of VRAM.

As for the attention question, I was using the default attention. With the CUDA toolkit installed I am now able to use FlashAttention 2, but I have not seen any substantial memory or speed improvements from it.
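
Concretely, I just swapped the from_pretrained call in my __init__ to something like this; as far as I know the flash-attn package has to be installed for 'flash_attention_2' to be accepted:

model = AutoModelForCausalLM.from_pretrained(
    self.model_path,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
    use_cache=True,
    trust_remote_code=True
)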

It is still very interesting that neither of us is anywhere near the numbers in that blog post.


@blakemart15 Which version of torch and transformers are you using?
I am only able to get a context window of 1k on my 8x A100 40GB GPUs, since it is using LlamaAttention instead of LlamaSdpaAttention by default, which I think is down to something in my environment.


@jacobvinje I shut down the instance, so I can dig up the exact versions later this week. But I know torch was at least 2.6.x and transformers at least 4.48.x.
