Llama-2 CUDA OOM during inference but not training

It looks like I’m leaking CUDA memory during inference but not during training. It might have to do with the custom model I’ve implemented, but I don’t know how to debug it.

My custom model is a SiameseLlama, meaning it’s a Llama-2 model with two heads: one language modeling head and one classification head. Both heads share the same Llama-2 backbone. In practice, this is implemented by initializing a LlamaForSequenceClassification and a LlamaForCausalLM as two separate models, deleting the classifier’s backbone, and linking the language model’s backbone to the classifier like this:

del classifier.model
gc.collect()
classifier.model = lm.model
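
For context, the whole setup looks roughly like this (a minimal sketch; the checkpoint name and num_labels are hypothetical, and the 4-bit quantization config is omitted):

import gc
from transformers import LlamaForCausalLM, LlamaForSequenceClassification

# hypothetical checkpoint name and label count; the actual run uses a 4-bit quantized Llama-2
checkpoint = "meta-llama/Llama-2-7b-hf"

lm = LlamaForCausalLM.from_pretrained(checkpoint)
classifier = LlamaForSequenceClassification.from_pretrained(checkpoint, num_labels=5)

# drop the classifier's own backbone and point it at the LM's backbone,
# so both heads share the same base model weights
del classifier.model
gc.collect()
classifier.model = lm.model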

This is my SiameseLlama model file:

from utils.metrics import WeightedKappa
from configs.training import TrainConfig

from dataclasses import dataclass, field
import os
import torch
from transformers import (
    LlamaForSequenceClassification,
    LlamaForCausalLM,
    PreTrainedModel,
    BitsAndBytesConfig,
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from typing import Callable


type DataDict = dict[str, torch.Tensor]


@dataclass
class SiameseOutput:
    classifier_output: SequenceClassifierOutputWithPast | None = None
    lm_output: CausalLMOutputWithPast | None = None
    loss: torch.Tensor | None = None  # combined loss returned by SiameseLoss
    preds: DataDict = field(default_factory=dict)


class SiameseLoss(torch.nn.Module):
    def __init__(
        self,
        num_classes: int = 5,
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ) -> None:
        """
        `scoring_loss`: {"cross_entropy", "linear_weighted_kappa", "quadratic_weighted_kappa"}
        """
        super().__init__()
        if scoring_loss not in (
            valid_scoring_losses := {
                "cross_entropy",
                "linear_weighted_kappa",
                "quadratic_weighted_kappa",
            }
        ):
            raise AssertionError(f"`scoring_loss` must be in {str(valid_scoring_losses)}")
        self.scoring_loss = scoring_loss
        if self.scoring_loss != "cross_entropy":
            self.kappa_weighting = self.scoring_loss.split("_")[0]
        self.num_classes = num_classes
        # norm the score_feedback_ratio weights between 0 and 1 if they weren't already
        self.score_feedback_ratio: tuple = tuple(
            w / sum(score_feedback_ratio) for w in score_feedback_ratio
        )

    def forward(
        self,
        classifier_output: SequenceClassifierOutputWithPast | None = None,
        lm_output: CausalLMOutputWithPast | None = None,
        classifier_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `labels` can be left as None if `scoring_loss == "cross_entropy"`.
        Returns a torch.Tensor with a single value, can be backwarded.
        """
        if not (classifier_output or lm_output):
            raise AssertionError("Missing `classifier_output` and/or `lm_output`.")

        if classifier_output and classifier_labels is None:
            raise AssertionError(
                "When passing the classifier output, also have to pass the true labels."
            )

        if not classifier_output:
            classifier_loss = 0
        elif self.scoring_loss == "cross_entropy":
            if classifier_output.loss is None:
                raise ValueError(
                    "`classifier_output` does not have a `.loss` attribute. This probably means that the classification head did not receive the labels during its forward pass."
                )
            classifier_loss = classifier_output.loss
        else:  # weighted kappa loss
            weighted_kappa = WeightedKappa(
                weighting=self.kappa_weighting, num_classes=self.num_classes, as_loss=True
            )
            classifier_loss = weighted_kappa(output=classifier_output, labels=classifier_labels)

        if lm_output:
            if lm_output.loss is None:
                raise ValueError(
                    "`lm_output` does not have a `.loss` attribute. This probably means that the LM head did not receive the labels during its forward pass."
                )
            lm_loss = lm_output.loss
        else:
            lm_loss = 0

        siamese_loss = (
            classifier_loss * self.score_feedback_ratio[0] + lm_loss * self.score_feedback_ratio[1]
        )
        assert siamese_loss != 0

        return siamese_loss


class SiameseLlama:
    def __init__(
        self,
        models: dict[str, PreTrainedModel | None],
        prompting_strategy: str = "vanilla",
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ):
        if not models["causal_lm"] and not models["classifier"]:
            raise ValueError(
                "Must pass at least a `causal_lm` or `classifier` in the `models` dict."
            )

        self.score_feedback_ratio: tuple[float, float] = score_feedback_ratio
        self.scoring_loss: str = scoring_loss
        self.prompting_strategy: str = prompting_strategy
        self.tasks: list[str] = self.prompting_strategy.split("_then_")

        self.models = models
        for model in self.models.values():
            if model:
                self.parameters: Callable = model.parameters
                self.resize_token_embeddings: Callable = model.resize_token_embeddings
                # a bool that's True if model is in training mode, False if in eval mode
                self.training: bool = model.training
                break

    def train(self) -> None:
        self.training = True
        for model in self.models.values():
            if model:
                model.train()

    def eval(self) -> None:
        self.training = False
        for model in self.models.values():
            if model:
                model.eval()

    def to(self, device, *args, **kwargs) -> None:
        for model in self.models.values():
            if model:
                model.to(device)

    def multitask_forward(
        self,
        datapoint: dict[str, DataDict],
        **kwargs,
    ) -> SiameseOutput:
        """
        Runs one forward pass per task and returns a SiameseOutput holding the per-task huggingface outputs, the predictions, and the combined loss.
        """

        siamese_output = SiameseOutput()

        for task in datapoint.keys():
            task_data: DataDict = datapoint[task]
            for key in task_data:
                task_data[key] = task_data[key].to("cuda")

            if task == "score":
                task_output = self.models["classifier"](**task_data, **kwargs)
                siamese_output.classifier_output = task_output

            else:  # task == "feedback"
                task_output = self.models["causal_lm"](**task_data, **kwargs)
                siamese_output.lm_output = task_output

            siamese_output.preds[task] = torch.argmax(task_output.logits, dim=-1)

        siamese_loss = SiameseLoss(
            score_feedback_ratio=self.score_feedback_ratio, scoring_loss=self.scoring_loss
        )
        siamese_output.loss = siamese_loss(
            classifier_output=siamese_output.classifier_output,
            lm_output=siamese_output.lm_output,
            classifier_labels=datapoint.get("score", {}).get("labels", None),
        )

        return siamese_output
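
For reference, multitask_forward expects a nested dict keyed by task name. A minimal sketch of that structure (shapes and values are made up, and siamese_model stands for an already constructed SiameseLlama):

# hypothetical batch; shapes and values are made up
datapoint = {
    "score": {
        "input_ids": torch.randint(0, 32000, (2, 128)),
        "attention_mask": torch.ones(2, 128, dtype=torch.long),
        "labels": torch.tensor([3, 1]),  # class indices for the classification head
    },
    "feedback": {
        "input_ids": torch.randint(0, 32000, (2, 128)),
        "attention_mask": torch.ones(2, 128, dtype=torch.long),
        "labels": torch.randint(0, 32000, (2, 128)),  # token labels for the LM head
    },
}

output = siamese_model.multitask_forward(datapoint)
print(output.loss, list(output.preds.keys()))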

I run my training and eval script with this model while tracking CUDA stats. This is the output during the very first evaluation (before doing any training). Note how the memory stats increase from eval step 0 to eval step 1: step 0 completes successfully, but memory runs out during step 1.
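
The numbers below come from my memory-tracking utilities; roughly, they are based on the standard torch.cuda statistics, something like this hypothetical helper (not my exact code):

import torch

GB = 1024 ** 3

def print_cuda_stats(step: int) -> None:
    # hypothetical helper mirroring the printout below
    stats = torch.cuda.memory_stats()
    print(f"STEP {step}")
    print(f"CUDA memory at the start: {torch.cuda.memory_allocated() // GB} GB")
    print(f"Max CUDA memory allocated was {torch.cuda.max_memory_allocated() // GB} GB")
    print(f"Max CUDA memory reserved was {torch.cuda.max_memory_reserved() // GB} GB")
    print(f"Peak active CUDA memory was {stats['active_bytes.all.peak'] // GB} GB")
    print(f"Cuda Malloc retries : {stats['num_alloc_retries']}")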

STEP 0
CUDA memory at the start: 4 GB
Max CUDA memory allocated was 8 GB
Max CUDA memory reserved was 9 GB
Peak active CUDA memory was 8 GB
Cuda Malloc retries : 0
CPU Total Peak Memory consumed (max): 1 GB

evaluating Epoch:  50%|█████     | 1/2 [00:11<00:11, 11.92s/it]
STEP 1
CUDA memory at the start: 8 GB
Max CUDA memory allocated was 9 GB
Max CUDA memory reserved was 10 GB
Peak active CUDA memory was 9 GB
Cuda Malloc retries : 1
CPU Total Peak Memory consumed (max): 2 GB

evaluating Epoch:  50%|█████     | 1/2 [00:15<00:15, 15.32s/it]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
...
│ /workspace/students/mai/master/llm-feedbacks/train/utils/train_utils.py:342  │
│ in evaluation                                                                │
│                                                                              │
│   339 │   │   │   # with torch.no_grad():                                    │
│   340 │   │   │   │   # Forward pass and compute loss                        │
│   341 │   │   │   │   # outputs = model(**batch)                             │
│ ❱ 342 │   │   │   │   outputs = model.multitask_forward(batch)               │
│   343 │   │   │   │   loss = outputs.loss                                    │
│   344 │   │   │   │   eval_loss += loss.detach().float()                     │
│   345 │   │   │   # Decode predictions and add to evaluation predictions lis │
│                                                                              │
│ /workspace/students/mai/master/llm-feedbacks/train/models/siamese_llama.py:1 │
│ 66 in multitask_forward                                                      │
│                                                                              │
│   163 │   │   │   │   task_data[key] = task_data[key].to("cuda")             │
│   164 │   │   │                                                              │
│   165 │   │   │   if task == "score":                                        │
│ ❱ 166 │   │   │   │   task_output = self.models["classifier"](**task_data, * │
│   167 │   │   │   │   siamese_output.classifier_output = task_output         │
│   168 │   │   │                                                              │
│   169 │   │   │   else:  # task == "feedback"                                │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:1352 in forward                   │
│                                                                              │
│   1349 │   │   """                                                           │
│   1350 │   │   return_dict = return_dict if return_dict is not None else sel │
│   1351 │   │                                                                 │
│ ❱ 1352 │   │   transformer_outputs = self.model(                             │
│   1353 │   │   │   input_ids,                                                │
│   1354 │   │   │   attention_mask=attention_mask,                            │
│   1355 │   │   │   position_ids=position_ids,                                │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:968 in forward                    │
│                                                                              │
│    965 │   │   │   │   │   cache_position,                                   │
│    966 │   │   │   │   )                                                     │
│    967 │   │   │   else:                                                     │
│ ❱  968 │   │   │   │   layer_outputs = decoder_layer(                        │
│    969 │   │   │   │   │   hidden_states,                                    │
│    970 │   │   │   │   │   attention_mask=causal_mask,                       │
│    971 │   │   │   │   │   position_ids=position_ids,                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:727 in forward                    │
│                                                                              │
│    724 │   │   # Fully Connected                                             │
│    725 │   │   residual = hidden_states                                      │
│    726 │   │   hidden_states = self.post_attention_layernorm(hidden_states)  │
│ ❱  727 │   │   hidden_states = self.mlp(hidden_states)                       │
│    728 │   │   hidden_states = residual + hidden_states                      │
│    729 │   │                                                                 │
│    730 │   │   outputs = (hidden_states,)                                    │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:216 in forward                    │
│                                                                              │
│    213 │   │   │   ]                                                         │
│    214 │   │   │   down_proj = sum(down_proj)                                │
│    215 │   │   else:                                                         │
│ ❱  216 │   │   │   down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) │
│    217 │   │                                                                 │
│    218 │   │   return down_proj                                              │
│    219                                                                       │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/nn/modules.py:468 in forward                                     │
│                                                                              │
│   465 │   │   │   x = x.to(self.compute_dtype)                               │
│   466 │   │                                                                  │
│   467 │   │   bias = None if self.bias is None else self.bias.to(self.comput │
│ ❱ 468 │   │   out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_sta │
│   469 │   │                                                                  │
│   470 │   │   out = out.to(inp_dtype)                                        │
│   471                                                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:579 in matmul_4bit                        │
│                                                                              │
│   576 │   │   │   │   out += bias                                            │
│   577 │   │   │   return out                                                 │
│   578 │   else:                                                              │
│ ❱ 579 │   │   return MatMul4Bit.apply(A, B, out, bias, quant_state)          │
│   580                                                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/autograd/function.py:598 in apply                                       │
│                                                                              │
│   595 │   │   if not torch._C._are_functorch_transforms_active():            │
│   596 │   │   │   # See NOTE: [functorch vjp and autograd interaction]       │
│   597 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)         │
│ ❱ 598 │   │   │   return super().apply(*args, **kwargs)  # type: ignore[misc │
│   599 │   │                                                                  │
│   600 │   │   if not is_setup_ctx_defined:                                   │
│   601 │   │   │   raise RuntimeError(                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:509 in forward                            │
│                                                                              │
│   506 │   │                                                                  │
│   507 │   │   # 1. Dequantize                                                │
│   508 │   │   # 2. MatmulnN                                                  │
│ ❱ 509 │   │   output = torch.nn.functional.linear(A, F.dequantize_4bit(B, qu │
│   510 │   │                                                                  │
│   511 │   │   # 3. Save state                                                │
│   512 │   │   ctx.state = quant_state                                        │
╰──────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB. GPU 
E0628 17:14:30.162000 22490380526784 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 17678) of binary: .../python
...

Important: I’m using:

  • a custom PyTorch training and evaluation loop, not the huggingface Trainer API.
  • bitsandbytes 4-bit quantization.

Things that I’ve tried:

  • Setting model.eval(). This seems to be part of the problem: when I replace it with model.train(), memory stats stay constant from step to step and I don’t run into memory issues.
  • Calling gc.collect() and torch.cuda.empty_cache() before and after every eval step, and wrapping the forward pass in torch.no_grad(). No positive effect (see the sketch after this list).
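
A rough sketch of what that second attempt looked like (a hypothetical loop, assuming model and eval_dataloader are already set up; variable names differ from my actual train_utils):

import gc
import torch

# hypothetical eval loop sketch; this is what was tried, with no positive effect
model.eval()
eval_loss = 0.0
for step, batch in enumerate(eval_dataloader):
    gc.collect()
    torch.cuda.empty_cache()

    with torch.no_grad():
        outputs = model.multitask_forward(batch)
        eval_loss += outputs.loss.detach().float()

    gc.collect()
    torch.cuda.empty_cache()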

Does anybody have any idea what might be causing the problem and/or how I could go about debugging it?

Since setting model.train() seems to work, would evaluating in training mode be an option if I can’t figure out a way to make this work in eval mode? If so, is there anything I have to keep in mind when doing that?

Please let me know if you need additional information or code. I haven’t put together a minimal working example yet, but if you need one, I’ll work on it.
Many thanks in advance, any comments would be much appreciated!

I’ve found the solution. Manually deleting the model output after every eval step with del output seems to fix the memory leak. I still don’t understand why this isn’t necessary during training though.
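
In case it helps anyone, the eval step that works for me now looks roughly like this (a sketch, assuming the same model and eval_dataloader as above; names are illustrative):

# sketch of the fixed eval step; the key line is the del
for step, batch in enumerate(eval_dataloader):
    outputs = model.multitask_forward(batch)
    eval_loss += outputs.loss.detach().float()
    # ... decode predictions here ...

    # explicitly drop the output object so its tensors (logits, loss and
    # whatever they keep alive) can be freed before the next step
    del outputs
    gc.collect()
    torch.cuda.empty_cache()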

I’ll link the PyTorch forum post here, which contains a reproduction.
