It looks like I’m leaking CUDA memory during inference but not during training. It might have to do with the custom model that I’ve implemented but I don’t know how to debug it.
My custom model is a `SiameseLlama`, meaning it's a Llama-2 model with two heads: one language modeling head and one classification head. Both heads share the same Llama-2 backbone. In practice, this is implemented by initializing a `LlamaForSequenceClassification` and a `LlamaForCausalLM` as two separate models, deleting the classifier's backbone model, and linking the language model's backbone to the classifier like this:
```python
del classifier.model
gc.collect()
classifier.model = lm.model
```
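For context, the two models are built roughly like this. This is a simplified sketch, not my exact loading code: the model name and quantization config here are placeholders, and `SiameseLlama` is the wrapper from the model file below.

```python
import gc

from transformers import (
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
)

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder
bnb_config = BitsAndBytesConfig(load_in_4bit=True)  # placeholder, full config further down

# two separate heads, each initially with its own backbone
lm = LlamaForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
classifier = LlamaForSequenceClassification.from_pretrained(
    model_name, num_labels=5, quantization_config=bnb_config
)

# drop the classifier's own backbone and point it at the LM's backbone,
# so both heads share the same Llama-2 weights
del classifier.model
gc.collect()
classifier.model = lm.model

siamese = SiameseLlama(models={"causal_lm": lm, "classifier": classifier})
```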
This is my `SiameseLlama` model file:
```python
from utils.metrics import WeightedKappa
from configs.training import TrainConfig
from dataclasses import dataclass, field
import os
import torch
from transformers import (
    LlamaForSequenceClassification,
    LlamaForCausalLM,
    PreTrainedModel,
    BitsAndBytesConfig,
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from typing import Callable

type DataDict = dict[str, torch.Tensor]


@dataclass
class SiameseOutput:
    classifier_output: SequenceClassifierOutputWithPast | None = None
    lm_output: CausalLMOutputWithPast | None = None
    # the combined loss returned by SiameseLoss.forward
    loss: torch.Tensor | None = None
    preds: DataDict = field(default_factory=dict)


class SiameseLoss(torch.nn.Module):
    def __init__(
        self,
        num_classes: int = 5,
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ) -> None:
        """
        `scoring_loss`: {"cross_entropy", "linear_weighted_kappa", "quadratic_weighted_kappa"}
        """
        super().__init__()
        if scoring_loss not in (
            valid_scoring_losses := {
                "cross_entropy",
                "linear_weighted_kappa",
                "quadratic_weighted_kappa",
            }
        ):
            raise AssertionError(f"`scoring_loss` must be in {valid_scoring_losses}")
        self.scoring_loss = scoring_loss
        if self.scoring_loss != "cross_entropy":
            self.kappa_weighting = self.scoring_loss.split("_")[0]
        self.num_classes = num_classes
        # normalize the score_feedback_ratio weights so they sum to 1 if they don't already
        self.score_feedback_ratio: tuple = tuple(
            w / sum(score_feedback_ratio) for w in score_feedback_ratio
        )

    def forward(
        self,
        classifier_output: SequenceClassifierOutputWithPast | None = None,
        lm_output: CausalLMOutputWithPast | None = None,
        classifier_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `classifier_labels` can be left as None if `scoring_loss == "cross_entropy"`.
        Returns a single-element torch.Tensor that can be backwarded.
        """
        if not (classifier_output or lm_output):
            raise AssertionError("Missing `classifier_output` and/or `lm_output`.")
        if classifier_output and classifier_labels is None:
            raise AssertionError(
                "When passing the classifier output, the true labels also have to be passed."
            )
        if not classifier_output:
            classifier_loss = 0
        elif self.scoring_loss == "cross_entropy":
            if classifier_output.loss is None:
                raise ValueError(
                    "`classifier_output.loss` is None. This probably means that the "
                    "classification head did not receive the labels during its forward pass."
                )
            classifier_loss = classifier_output.loss
        else:  # weighted kappa loss
            weighted_kappa = WeightedKappa(
                weighting=self.kappa_weighting, num_classes=self.num_classes, as_loss=True
            )
            classifier_loss = weighted_kappa(output=classifier_output, labels=classifier_labels)
        if lm_output:
            if lm_output.loss is None:
                raise ValueError(
                    "`lm_output.loss` is None. This probably means that the LM head "
                    "did not receive the labels during its forward pass."
                )
            lm_loss = lm_output.loss
        else:
            lm_loss = 0
        siamese_loss = (
            classifier_loss * self.score_feedback_ratio[0] + lm_loss * self.score_feedback_ratio[1]
        )
        assert siamese_loss != 0
        return siamese_loss


class SiameseLlama:
    def __init__(
        self,
        models: dict[str, PreTrainedModel | None],
        prompting_strategy: str = "vanilla",
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ):
        if not models["causal_lm"] and not models["classifier"]:
            raise ValueError(
                "Must pass at least a `causal_lm` or `classifier` in the `models` dict."
            )
        self.score_feedback_ratio: tuple[float, float] = score_feedback_ratio
        self.scoring_loss: str = scoring_loss
        self.prompting_strategy: str = prompting_strategy
        self.tasks: list[str] = self.prompting_strategy.split("_then_")
        self.models = models
        for model in self.models.values():
            if model:
                self.parameters: Callable = model.parameters
                self.resize_token_embeddings: Callable = model.resize_token_embeddings
                # a bool that's True if the model is in training mode, False if in eval mode
                self.training: bool = model.training
                break

    def train(self) -> None:
        self.training = True
        for model in self.models.values():
            if model:
                model.train()

    def eval(self) -> None:
        self.training = False
        for model in self.models.values():
            if model:
                model.eval()

    def to(self, device, *args, **kwargs) -> None:
        for model in self.models.values():
            if model:
                model.to(device)

    def multitask_forward(
        self,
        datapoint: dict[str, DataDict],
        **kwargs,
    ) -> SiameseOutput:
        """
        Returns a loss if labels were given. Otherwise, returns the huggingface output objects only.
        """
        siamese_output = SiameseOutput()
        for task in datapoint.keys():
            task_data: DataDict = datapoint[task]
            for key in task_data:
                task_data[key] = task_data[key].to("cuda")

            if task == "score":
                task_output = self.models["classifier"](**task_data, **kwargs)
                siamese_output.classifier_output = task_output
            else:  # task == "feedback"
                task_output = self.models["causal_lm"](**task_data, **kwargs)
                siamese_output.lm_output = task_output

            siamese_output.preds[task] = torch.argmax(task_output.logits, dim=-1)

        siamese_loss = SiameseLoss(
            score_feedback_ratio=self.score_feedback_ratio, scoring_loss=self.scoring_loss
        )
        siamese_output.loss = siamese_loss(
            classifier_output=siamese_output.classifier_output,
            lm_output=siamese_output.lm_output,
            classifier_labels=datapoint.get("score", {}).get("labels", None),
        )
        return siamese_output
```
I run my training and eval script with this model while tracking CUDA stats. This is the output during the very first evaluation (before doing any training). Please note how the memory stats increase from eval step 0 to eval step 1: eval step 0 completes successfully, but memory runs out during step 1.
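The stats are collected with a small helper, roughly like this sketch built on the standard `torch.cuda` APIs (my actual helper also tracks peak CPU memory; the names here are illustrative):

```python
import torch

GIB = 1024**3


def print_cuda_stats() -> None:
    # rough sketch of the per-step memory reporting
    print(f"CUDA memory at the start: {torch.cuda.memory_allocated() / GIB:.0f} GB")
    print(f"Max CUDA memory allocated was {torch.cuda.max_memory_allocated() / GIB:.0f} GB")
    print(f"Max CUDA memory reserved was {torch.cuda.max_memory_reserved() / GIB:.0f} GB")
    stats = torch.cuda.memory_stats()
    print(f"Peak active CUDA memory was {stats['active_bytes.all.peak'] / GIB:.0f} GB")
    print(f"Cuda Malloc retries : {stats.get('num_alloc_retries', 0)}")
```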
STEP 0
CUDA memory at the start: 4 GB
Max CUDA memory allocated was 8 GB
Max CUDA memory reserved was 9 GB
Peak active CUDA memory was 8 GB
Cuda Malloc retries : 0
CPU Total Peak Memory consumed (max): 1 GB
evaluating Epoch: 50%|█████     | 1/2 [00:11<00:11, 11.92s/it]
STEP 1
CUDA memory at the start: 8 GB
Max CUDA memory allocated was 9 GB
Max CUDA memory reserved was 10 GB
Peak active CUDA memory was 9 GB
Cuda Malloc retries : 1
CPU Total Peak Memory consumed (max): 2 GB
evaluating Epoch: 50%|█████     | 1/2 [00:15<00:15, 15.32s/it]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
...
... │ │
│ /workspace/students/mai/master/llm-feedbacks/train/utils/train_utils.py:342 │
│ in evaluation │
│ │
│ 339 │ │ │ # with torch.no_grad(): │
│ 340 │ │ │ │ # Forward pass and compute loss │
│ 341 │ │ │ │ # outputs = model(**batch) │
│ ❱ 342 │ │ │ │ outputs = model.multitask_forward(batch) │
│ 343 │ │ │ │ loss = outputs.loss │
│ 344 │ │ │ │ eval_loss += loss.detach().float() │
│ 345 │ │ │ # Decode predictions and add to evaluation predictions lis │
│ │
│ /workspace/students/mai/master/llm-feedbacks/train/models/siamese_llama.py:1 │
│ 66 in multitask_forward │
│ │
│ 163 │ │ │ │ task_data[key] = task_data[key].to("cuda") │
│ 164 │ │ │ │
│ 165 │ │ │ if task == "score": │
│ ❱ 166 │ │ │ │ task_output = self.models["classifier"](**task_data, * │
│ 167 │ │ │ │ siamese_output.classifier_output = task_output │
│ 168 │ │ │ │
│ 169 │ │ │ else: # task == "feedback" │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl │
│ │
│ 1529 │ │ if self._compiled_call_impl is not None: │
│ 1530 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: │
│ 1531 │ │ else: │
│ ❱ 1532 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1533 │ │
│ 1534 │ def _call_impl(self, *args, **kwargs): │
│ 1535 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_s │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl │
│ │
│ 1538 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1539 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1540 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │ │ │ return forward_call(*args, **kwargs) │
│ 1542 │ │ │
│ 1543 │ │ try: │
│ 1544 │ │ │ result = None │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward │
│ │
│ 163 │ │ │ with torch.no_grad(): │
│ 164 │ │ │ │ output = module._old_forward(*args, **kwargs) │
│ 165 │ │ else: │
│ ❱ 166 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 167 │ │ return module._hf_hook.post_forward(module, output) │
│ 168 │ │
│ 169 │ # Overriding a GraphModuleImpl forward freezes the forward call an │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:1352 in forward │
│ │
│ 1349 │ │ """ │
│ 1350 │ │ return_dict = return_dict if return_dict is not None else sel │
│ 1351 │ │ │
│ ❱ 1352 │ │ transformer_outputs = self.model( │
│ 1353 │ │ │ input_ids, │
│ 1354 │ │ │ attention_mask=attention_mask, │
│ 1355 │ │ │ position_ids=position_ids, │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl │
│ │
│ 1529 │ │ if self._compiled_call_impl is not None: │
│ 1530 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: │
│ 1531 │ │ else: │
│ ❱ 1532 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1533 │ │
│ 1534 │ def _call_impl(self, *args, **kwargs): │
│ 1535 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_s │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl │
│ │
│ 1538 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1539 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1540 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │ │ │ return forward_call(*args, **kwargs) │
│ 1542 │ │ │
│ 1543 │ │ try: │
│ 1544 │ │ │ result = None │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward │
│ │
│ 163 │ │ │ with torch.no_grad(): │
│ 164 │ │ │ │ output = module._old_forward(*args, **kwargs) │
│ 165 │ │ else: │
│ ❱ 166 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 167 │ │ return module._hf_hook.post_forward(module, output) │
│ 168 │ │
│ 169 │ # Overriding a GraphModuleImpl forward freezes the forward call an │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:968 in forward │
│ │
│ 965 │ │ │ │ │ cache_position, │
│ 966 │ │ │ │ ) │
│ 967 │ │ │ else: │
│ ❱ 968 │ │ │ │ layer_outputs = decoder_layer( │
│ 969 │ │ │ │ │ hidden_states, │
│ 970 │ │ │ │ │ attention_mask=causal_mask, │
│ 971 │ │ │ │ │ position_ids=position_ids, │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl │
│ │
│ 1529 │ │ if self._compiled_call_impl is not None: │
│ 1530 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: │
│ 1531 │ │ else: │
│ ❱ 1532 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1533 │ │
│ 1534 │ def _call_impl(self, *args, **kwargs): │
│ 1535 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_s │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl │
│ │
│ 1538 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1539 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1540 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │ │ │ return forward_call(*args, **kwargs) │
│ 1542 │ │ │
│ 1543 │ │ try: │
│ 1544 │ │ │ result = None │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward │
│ │
│ 163 │ │ │ with torch.no_grad(): │
│ 164 │ │ │ │ output = module._old_forward(*args, **kwargs) │
│ 165 │ │ else: │
│ ❱ 166 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 167 │ │ return module._hf_hook.post_forward(module, output) │
│ 168 │ │
│ 169 │ # Overriding a GraphModuleImpl forward freezes the forward call an │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:727 in forward │
│ │
│ 724 │ │ # Fully Connected │
│ 725 │ │ residual = hidden_states │
│ 726 │ │ hidden_states = self.post_attention_layernorm(hidden_states) │
│ ❱ 727 │ │ hidden_states = self.mlp(hidden_states) │
│ 728 │ │ hidden_states = residual + hidden_states │
│ 729 │ │ │
│ 730 │ │ outputs = (hidden_states,) │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl │
│ │
│ 1529 │ │ if self._compiled_call_impl is not None: │
│ 1530 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: │
│ 1531 │ │ else: │
│ ❱ 1532 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1533 │ │
│ 1534 │ def _call_impl(self, *args, **kwargs): │
│ 1535 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_s │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl │
│ │
│ 1538 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1539 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1540 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │ │ │ return forward_call(*args, **kwargs) │
│ 1542 │ │ │
│ 1543 │ │ try: │
│ 1544 │ │ │ result = None │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward │
│ │
│ 163 │ │ │ with torch.no_grad(): │
│ 164 │ │ │ │ output = module._old_forward(*args, **kwargs) │
│ 165 │ │ else: │
│ ❱ 166 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 167 │ │ return module._hf_hook.post_forward(module, output) │
│ 168 │ │
│ 169 │ # Overriding a GraphModuleImpl forward freezes the forward call an │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:216 in forward │
│ │
│ 213 │ │ │ ] │
│ 214 │ │ │ down_proj = sum(down_proj) │
│ 215 │ │ else: │
│ ❱ 216 │ │ │ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) │
│ 217 │ │ │
│ 218 │ │ return down_proj │
│ 219 │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl │
│ │
│ 1529 │ │ if self._compiled_call_impl is not None: │
│ 1530 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: │
│ 1531 │ │ else: │
│ ❱ 1532 │ │ │ return self._call_impl(*args, **kwargs) │
│ 1533 │ │
│ 1534 │ def _call_impl(self, *args, **kwargs): │
│ 1535 │ │ forward_call = (self._slow_forward if torch._C._get_tracing_s │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl │
│ │
│ 1538 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1539 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1540 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │ │ │ return forward_call(*args, **kwargs) │
│ 1542 │ │ │
│ 1543 │ │ try: │
│ 1544 │ │ │ result = None │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward │
│ │
│ 163 │ │ │ with torch.no_grad(): │
│ 164 │ │ │ │ output = module._old_forward(*args, **kwargs) │
│ 165 │ │ else: │
│ ❱ 166 │ │ │ output = module._old_forward(*args, **kwargs) │
│ 167 │ │ return module._hf_hook.post_forward(module, output) │
│ 168 │ │
│ 169 │ # Overriding a GraphModuleImpl forward freezes the forward call an │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/nn/modules.py:468 in forward │
│ │
│ 465 │ │ │ x = x.to(self.compute_dtype) │
│ 466 │ │ │
│ 467 │ │ bias = None if self.bias is None else self.bias.to(self.comput │
│ ❱ 468 │ │ out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_sta │
│ 469 │ │ │
│ 470 │ │ out = out.to(inp_dtype) │
│ 471 │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:579 in matmul_4bit │
│ │
│ 576 │ │ │ │ out += bias │
│ 577 │ │ │ return out │
│ 578 │ else: │
│ ❱ 579 │ │ return MatMul4Bit.apply(A, B, out, bias, quant_state) │
│ 580 │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/autograd/function.py:598 in apply │
│ │
│ 595 │ │ if not torch._C._are_functorch_transforms_active(): │
│ 596 │ │ │ # See NOTE: [functorch vjp and autograd interaction] │
│ 597 │ │ │ args = _functorch.utils.unwrap_dead_wrappers(args) │
│ ❱ 598 │ │ │ return super().apply(*args, **kwargs) # type: ignore[misc │
│ 599 │ │ │
│ 600 │ │ if not is_setup_ctx_defined: │
│ 601 │ │ │ raise RuntimeError( │
│ │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:509 in forward │
│ │
│ 506 │ │ │
│ 507 │ │ # 1. Dequantize │
│ 508 │ │ # 2. MatmulnN │
│ ❱ 509 │ │ output = torch.nn.functional.linear(A, F.dequantize_4bit(B, qu │
│ 510 │ │ │
│ 511 │ │ # 3. Save state │
│ 512 │ │ ctx.state = quant_state │
╰──────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB. GPU
E0628 17:14:30.162000 22490380526784 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 17678) of binary: .../python
...
Important: I'm using:
- a custom PyTorch training and evaluation loop, not the huggingface `Trainer` API,
- bitsandbytes 4-bit quantization (the config is sketched below).
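For reference, the quantization setup looks roughly like this. This is a sketch from memory; the exact flags live in my `TrainConfig` and may differ slightly:

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization via bitsandbytes (values here are illustrative)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
```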
Things that I've tried:
- Setting `model.eval()`. This seems to be causing the problem: when I replace it with `model.train()`, the memory stats are constant from step to step and I don't run into memory issues.
- Calling `gc.collect()` and `torch.cuda.empty_cache()` before and after every eval step, and wrapping the forward pass in `with torch.no_grad()`. No positive effect. (The relevant part of the eval loop is sketched below.)
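Roughly what that part of my evaluation loop looks like with these attempts in place (simplified; `model` is the `SiameseLlama` wrapper and `eval_dataloader` is a regular PyTorch dataloader from my script):

```python
import gc

import torch

model.eval()
eval_loss = 0.0
for step, batch in enumerate(eval_dataloader):
    gc.collect()
    torch.cuda.empty_cache()

    # forward pass and loss computation; the OOM happens in here during step 1
    with torch.no_grad():
        outputs = model.multitask_forward(batch)
    loss = outputs.loss
    eval_loss += loss.detach().float()

    gc.collect()
    torch.cuda.empty_cache()
```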
Does anybody have any idea what might be causing the problem and/or how I could go about debugging it?
Since setting `model.train()` seems to work, would evaluating in training mode be an option if I can't figure out a way to make this work in eval mode? If so, is there anything I have to keep in mind when doing that?
Please let me know if you need additional information or code. I haven't put together an actual minimal working example yet, but if you need one, I'll work on it.
Many thanks in advance, any comments would be much appreciated!