This question is about Llama3-8B throwing an OOM error when I do causal language modelling on an A100. I don't quantize the model, but even in full precision an 8B-parameter model should fit comfortably on an 80GB GPU.
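For context, this is the back-of-envelope arithmetic I'm going on (rough numbers for the weights only, ignoring activations and the KV cache):

n_params = 8e9
print(n_params * 2 / 1e9)  # ~16 GB of weights in fp16/bf16
print(n_params * 4 / 1e9)  # ~32 GB of weights in fp32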
This is the output of nvidia-smi
Wed Jun 5 03:08:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          Off |   00000000:00:0C.0 Off |                    0 |
| N/A   32C    P0              54W / 300W |       3MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
This is my code for calculating the perplexity per generated token. I realize it is quite inefficient (token-by-token, without batching), but it is something I just cooked up, so any suggestions are welcome.
import torch as t
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from datasets import load_dataset


def top_k_top_p_filtering(
    logits: t.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> t.FloatTensor:
    # Mask out logits outside the top-k / nucleus (top-p) set using the
    # HF logits warpers, so only those candidates survive the softmax.
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )
    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )
    return logits
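
# For reference, a toy sanity check I did of the filtering function
# (made-up logits, not from the real run): with top_k=2, everything
# outside the two largest logits should be pushed to -inf before softmax.
toy_logits = t.tensor([[1.0, 3.0, 0.5, 2.0]])
filtered = top_k_top_p_filtering(toy_logits, top_k=2, top_p=1.0)
print(filtered)  # expected: tensor([[-inf, 3., -inf, 2.]])
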
def calculate_mean_nll_loss(model, tokenizer, prompt):
    num_sampled = 0
    nll_vals = []
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    while True:
        outputs = model(input_ids)
        # Get logits from last layer
        last_layer_logits = outputs.logits[:, -1, :]
        top_logits = top_k_top_p_filtering(last_layer_logits, top_k=10, top_p=1.0)
        probs = F.softmax(top_logits)
        token_idx = t.argmax(probs, dim=-1)
        probability = probs[0, token_idx]
        generated_next_token = t.multinomial(probs, num_samples=1)
        nll_vals.append(-1 * t.log(probability))
        # Once the model is done predicting, we calculate the perplexity by taking the
        # exp of the mean nll values.
        if generated_next_token == tokenizer.eos_token_id:
            print("Ended")
            ppl = t.exp(t.stack(nll_vals, dim=0).mean())
            return ppl
        input_ids = t.cat([input_ids, generated_next_token], dim=-1)
        num_sampled += 1
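The perplexity I return is just the exp of the mean per-token NLL, which I sanity-checked on made-up probabilities like this:

fake_probs = t.tensor([0.5, 0.25, 0.8])  # made-up per-token probabilities
fake_nll = -t.log(fake_probs)
print(t.exp(fake_nll.mean()))  # exp(mean(-log p)), roughly 2.15 here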
This is how I call it.
model = AutoModelForCausalLM.from_pretrained(LLAMA3_MODEL).to(device)
tokenizer = AutoTokenizer.from_pretrained(LLAMA3_MODEL)
dataset = load_dataset("THUDM/humaneval-x", split="test")
prompts = dataset["prompt"]
output = calculate_mean_nll_loss(model, tokenizer, prompts[0])
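In case it helps with diagnosing, this is the kind of memory logging I was planning to add inside the while-loop (just a sketch):

def log_cuda_memory(step):
    # Print allocated/reserved CUDA memory so I can see whether usage
    # grows with every generated token.
    alloc_gb = t.cuda.memory_allocated() / 1e9
    reserved_gb = t.cuda.memory_reserved() / 1e9
    print(f"step {step}: allocated={alloc_gb:.2f} GB, reserved={reserved_gb:.2f} GB")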
Why does this happen? What am I doing wrong in the while-loop, or elsewhere?