Llama3 OutOfMemory on an A100 when doing CausalLM

This question is about Llama3-8B throwing an OOM error when I do causal language modelling on an A100. I don't quantize the model, but even so, an 8B-parameter model should fit comfortably on an 80GB GPU.
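
To make that expectation concrete, here is a rough back-of-the-envelope estimate of the memory the weights alone should need. It assumes roughly 8 billion parameters and ignores activations, caches, and the CUDA context. I do not pass a dtype when loading, so I am assuming float32; the helper name `weight_memory_gib` is only for illustration.

```python
# Rough estimate of GPU memory needed for the model weights only.
# Assumption: ~8e9 parameters; activations and caches are NOT counted.

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Return the size of the weights in GiB for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


if __name__ == "__main__":
    for dtype in ("float32", "bfloat16"):
        print(f"{dtype}: {weight_memory_gib(8e9, dtype):.1f} GiB")
```

Under these assumptions the weights come to about 30 GiB in float32 and about 15 GiB in bfloat16, well under the 80 GiB the card reports, which is why the OOM surprised me.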

This is the output of nvidia-smi:

Wed Jun  5 03:08:43 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          Off |   00000000:00:0C.0 Off |                    0 |
| N/A   32C    P0             54W /  300W |       3MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

This is my code for calculating the perplexity over the generated tokens. I realize it is quite inefficient (token by token, without batching), but it is something I put together quickly, so any suggestions are welcome.

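import torch as t
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TopKLogitsWarper, TopPLogitsWarper

def top_k_top_p_filtering(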
    logits: t.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> t.FloatTensor:

    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits

def calculate_mean_nll_loss(model, tokenizer, prompt):

    num_sampled = 0
    nll_vals = []    
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    while True:
        outputs = model(input_ids)
        # Get the logits for the last position in the sequence
        last_layer_logits = outputs.logits[:, -1, :]

        top_logits = top_k_top_p_filtering(last_layer_logits, top_k=10, top_p=1.0)
        probs = F.softmax(top_logits, dim=-1)
        token_idx = t.argmax(probs, dim=-1)
        probability = probs[0,token_idx]
        generated_next_token = t.multinomial(probs, num_samples=1)
        nll_vals.append(-1 * t.log(probability))

        # Once the model is done predicting, we calculate the perplexity by taking the 
        # exp of the mean nll values.
        if generated_next_token == tokenizer.eos_token_id:
            print("Ended")
            ppl = t.exp(t.stack(nll_vals, dim=0).mean())
            return ppl
        
        input_ids = t.cat([input_ids, generated_next_token], dim=-1)
        num_sampled += 1
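
For clarity, the quantity I am after is the exponential of the mean negative log-likelihood of the chosen tokens. A minimal plain-Python sketch of that computation, independent of the model (the function name `perplexity_from_probs` and the example probabilities are made up for illustration):

```python
import math
from typing import Sequence


def perplexity_from_probs(token_probs: Sequence[float]) -> float:
    """Perplexity = exp(mean(-log p_i)) over the per-token probabilities."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    nll_vals = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll_vals) / len(nll_vals))


if __name__ == "__main__":
    # Hypothetical probabilities for three generated tokens.
    print(perplexity_from_probs([0.5, 0.25, 0.5]))
```

The loop above is meant to do the same thing with tensors, appending one NLL value per step until the EOS token is sampled.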

This is how I call it:

model = AutoModelForCausalLM.from_pretrained(LLAMA3_MODEL).to(device)
tokenizer = AutoTokenizer.from_pretrained(LLAMA3_MODEL)
dataset = load_dataset("THUDM/humaneval-x", split="test")
prompts = dataset["prompt"]
output = calculate_mean_nll_loss(model, tokenizer, prompts[0])

Why does this happen? What am I doing wrong in the while loop or elsewhere?