I am currently running Llama 3.1 8B Instruct through Python (Transformers) on an NVIDIA A100 80 GB, and I am trying to confirm that my GPU can handle the maximum context length of 128k tokens. According to this blog post, in FP16 the model needs ~16 GB of VRAM for the weights plus ~16 GB of VRAM for the KV cache at the full 128k context window, i.e. ~32 GB of VRAM in total to run Llama 3.1 8B Instruct at full context. However, when I run my code in FP16 I run out of VRAM at ~60k tokens, which is substantially less than what the blog post suggests.
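(As a sanity check on the blog's numbers, here is my own back-of-the-envelope math, plugging in values from the model's config.json: 32 layers, 8 KV heads with GQA, head dim 128, 2 bytes per element in FP16/BF16. The variable names below are just mine.)

# Back-of-the-envelope KV-cache size for Llama 3.1 8B at 128k context
num_layers = 32          # num_hidden_layers
num_kv_heads = 8         # num_key_value_heads (GQA)
head_dim = 128           # hidden_size / num_attention_heads = 4096 / 32
bytes_per_elem = 2       # FP16 / BF16
context_len = 128_000

# one K and one V vector per KV head, per layer, per token
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
print(kv_bytes_per_token)                      # 131072 bytes, i.e. 128 KiB per token
print(kv_bytes_per_token * context_len / 1e9)  # ~16.8 GB for the full window

That lines up with the blog's ~16 GB figure for the cache, so on paper an 80 GB A100 should have plenty of headroom.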
Here is the code I am using:
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class SLM:
    def __init__(self, model_path, device, max_new_tokens, do_sample):
        self.device = device
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map='auto',
            torch_dtype=torch.bfloat16,
            use_cache=True,
            trust_remote_code=True
        )
        self.model = torch.compile(model)
        self.generation_args = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
        }
        logger.info(f"Model loaded from {self.model_path} on device {self.device}")

    def pipe(self, history):
        try:
            logger.info('Sending history through pipe')
            tokenized_chat = self.tokenize_history(history)
            output = self.model.generate(tokenized_chat['input_ids'], **self.generation_args)
            # keep only the newly generated tokens, dropping the echoed prompt
            input_length = tokenized_chat['input_ids'].shape[1]
            output = output[:, input_length:]
            output = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
            logger.info('Successfully sent history through pipe')
            return output
        except Exception as e:
            logger.error(f"Error during text generation: {e}")
            raise

    def tokenize_history(self, history, return_len=True):
        try:
            chat = self.tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
            tokenized_chat = self.tokenizer(chat, return_tensors="pt", return_length=True).to(self.device)
            if not return_len:
                tokenized_chat.pop('length')
            return tokenized_chat
        except Exception as e:
            logger.error(f"Error during tokenization: {e}")
            raise

    def generate_token_count(self, history):
        try:
            tokenized_chat = self.tokenize_history(history)
            token_count = tokenized_chat['length'][0].item()
            return token_count
        except Exception as e:
            logger.error(f"Error calculating token count: {e}")
            raise


slm = SLM(
    model_path=config.slm.model_path,
    device=config.slm.device,
    max_new_tokens=config.slm.max_new_tokens,
    do_sample=config.slm.do_sample
)
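For what it's worth, this is roughly how I drive it while watching memory; long_history below is just a stand-in for my actual multi-turn chat history (a list of {"role": ..., "content": ...} dicts that tokenizes to ~60k prompt tokens):

# long_history is a placeholder for my real chat history
print(slm.generate_token_count(long_history))   # ~60k prompt tokens
torch.cuda.reset_peak_memory_stats()
reply = slm.pipe(long_history)                  # this is where I hit the CUDA OOM
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")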
FYI, the relevant generation settings are:
max_new_tokens = 4000
do_sample = False
I am currently on NVIDIA driver 535.183.06 (per nvidia-smi) with CUDA 12.2; I see the same issue on newer drivers with CUDA 12.4.
Where am I going wrong? Is the blog post incorrect, is my software outdated, or is my implementation wrong? I also tested the code provided in the blog post and hit the same issue. If it helps, I am only using this for inference at the moment. In the meantime I am falling back to quantization, but I would like to understand where this approach goes wrong.
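(For completeness, the quantized fallback I mention is along these lines, a 4-bit load via bitsandbytes; the exact settings here are illustrative rather than the point of the question.)

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# interim workaround: load the model 4-bit quantized instead of BF16
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    config.slm.model_path,        # same path as above
    device_map='auto',
    quantization_config=bnb_config,
)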
Any help is much appreciated. Thank you!