Why does moving ML model initialization into a function prevent GPU OOM errors when del, gc.collect(), and torch.cuda.empty_cache() fail?

I was evaluating a list of model checkpoints in a loop like this:

import gc
import torch
from vllm import LLM

for model_name in model_list:
    model = LLM(model_name, trust_remote_code=True)
    results = evaluate_model(model, task)
    # explicit cleanup between checkpoints
    del model
    gc.collect()
    torch.cuda.empty_cache()

Despite explicitly deleting the model object, calling gc.collect(), and clearing the CUDA cache with torch.cuda.empty_cache(), I still encountered GPU out-of-memory (OOM) errors.
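A quick sanity check between iterations is to print PyTorch's allocator counters; if "allocated" stays high after cleanup, some Python object still references the weights. This is just a diagnostic sketch using standard torch.cuda introspection calls, not part of my eval code:

import torch

def report_gpu_memory(tag: str) -> None:
    # bytes held by live tensors vs. bytes reserved by PyTorch's caching allocator
    allocated_gib = torch.cuda.memory_allocated() / 2**30
    reserved_gib = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={allocated_gib:.2f} GiB reserved={reserved_gib:.2f} GiB")

# e.g. call report_gpu_memory("after cleanup") right after torch.cuda.empty_cache() in the loop above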

After experimenting, I moved the model instantiation into a separate function, like so:

def do_lm_eval(model_name: str, task: str) -> dict:
    gc.collect()
    torch.cuda.empty_cache()
    model = LLM(model_name, trust_remote_code=True)
    results = evaluate_model(model, task)
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return results

for model_name in model_list:
    results = do_lm_eval(model_name, task)

Surprisingly, this resolved the OOM errors completely. My questions are:

1. Why does moving the model instantiation into its own function prevent GPU OOM errors, even though I was already using del, gc.collect(), and torch.cuda.empty_cache() in the original code?
2. Is there something about Python’s memory management or PyTorch’s interaction with CUDA that makes scoping within a function more effective?
3. Are there additional best practices for managing memory when iterating over large models on GPU?
Additional context:

  • The issue occurs when evaluating multiple checkpoints of a large model.
  • I suspect memory fragmentation or lingering references may play a role, but gc.collect() doesn’t seem to fix it (see the sketch after this list).
  • Using a dedicated function appears to clean up resources more effectively, but I’m unsure why.
  • Any insights into why explicit cleanup doesn’t always work would be greatly appreciated!
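On the “lingering references” point, here is a minimal, self-contained illustration (a stand-in torch.nn.Linear instead of a vLLM LLM, plus a hypothetical run_eval_step helper) of how a caught exception can keep a model’s weights alive even after del + gc.collect():

import gc
import torch

def run_eval_step() -> None:
    model = torch.nn.Linear(8192, 8192).cuda()  # ~268 MB stand-in for LLM(model_name, ...)
    raise RuntimeError("eval failed midway")    # `model` is still a local of this frame

caught = None
try:
    run_eval_step()
except RuntimeError as exc:
    caught = exc  # exc.__traceback__ -> run_eval_step's frame -> its local `model`

gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # stays ~268 MB: the saved traceback pins the weights

del caught  # drop the exception, its traceback, and the frame it references
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # back to ~0

I don’t know for sure that something like this happened in my loop, but it shows that del + gc.collect() isn’t enough when anything else still points at the object (or at a frame that holds it).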

Note:

I didn’t try the suggestion in this vLLM log message:

INFO 12-05 13:11:12 model_runner.py:1404] If out-of-memory error occurs during cudagraph capture, consider decreasing gpu_memory_utilization or switching to eager mode. You can also reduce the max_num_seqs as needed to decrease memory usage.

The error I hit was:

ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
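For reference, that suggestion would look roughly like this on the LLM constructor (untested on my side; I believe these keyword arguments exist in recent vLLM releases, but names and defaults may vary by version):

from vllm import LLM

# hedged sketch of the log message's advice: adjust vLLM's memory budget,
# skip CUDA-graph capture ("eager mode"), and cap concurrent sequences
model = LLM(
    model_name,
    trust_remote_code=True,
    gpu_memory_utilization=0.80,  # fraction of GPU memory vLLM may claim (the two messages above suggest tuning it in opposite directions)
    enforce_eager=True,           # disable CUDA graph capture
    max_num_seqs=64,              # fewer concurrent sequences -> smaller KV-cache/activation footprint
)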

Full code:

# using lm-harness/eval: https://chatgpt.com/c/67477267-a740-8001-894a-c5f22b24cc5f
from socket import gethostname
import os
import gc
import torch
from vllm import LLM, SamplingParams
from lm_eval import evaluator, tasks
from lm_eval.api.model import LM
import lm_eval.evaluator 
from os.path import expanduser
import fire
import wandb
from datetime import datetime
from typing import List, Tuple

_STOP_TOKENS: list[str] = ["Solution:", "Problem:", "Question:", "USER:", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
print(f'Original stop toks I had once: {len(_STOP_TOKENS)=} {_STOP_TOKENS=}')

#     raw_str_2_train_str = lambda examples : {'text': [f'problem: {prob}\n\nsolution: {sol}' for prob, sol in zip(examples['problem'], examples['solution'])]}
#     new, v2, no trian yet raw_str_2_train_str = lambda examples : {'text': [f'Problem:\n{prob}\n\nSolution:\n{sol}' for prob, sol in zip(examples['problem'], examples['solution'])]}
STOP_TOKENS: list[str] = ["problem:", "problem: ", "problem:\n", "Problem:", "Problem: ", "Problem:\n", "Question:", "USER:", "USER"]
print(f'New stop tokens: {len(STOP_TOKENS)=} {STOP_TOKENS=}')

class MyLM(LM):
    def __init__(self, model, batch_size: int = 16):
        self.model = model
        super().__init__()
        self.batch_size = batch_size

    def loglikelihood(self, requests: List) -> List[Tuple[float, bool]]:
        results = []
        for request in requests:
            context, continuation = request.args  # Instance.args is a tuple, not a callable
            logprob = self.model.compute_loglikelihood(context, continuation)
            isgreedy = self.model.is_greedy(context, continuation)
            results.append((logprob, isgreedy))
        return results

    def loglikelihood_rolling(self, requests: List) -> List[float]:
        results = []
        for request in requests:
            context, = request.args  # Instance.args is a tuple, not a callable
            logprob = self.model.compute_rolling_loglikelihood(context)
            results.append(logprob)
        return results

    def generate_until(self, requests: List) -> List[str]:
        print(f'--- {self.generate_until=}')
        # params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=2048, stop='Problem:\n')
        # params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=2048, stop=STOP_TOKENS)
        # params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=512, stop=STOP_TOKENS)
        params = SamplingParams(temperature=0, top_p=1, max_tokens=512, stop=STOP_TOKENS)
        prompts: List[str] = [request.args[0] for request in requests]
        outputs: list = self.model.generate(prompts, params)
        results: list[str] = [output.outputs[0].text for output in outputs]
        # print prompt result, what is going on during eval?
        # alert on W&B and fail fast if the generation count doesn't match the prompt count
        if len(prompts) != len(results):
            wandb.run.alert(title="error during eval", text="len(prompts) != len(results)")
            raise AssertionError(f'Fatal error: {len(prompts)=} != {len(results)=}')
        for req_idx, (prompt, result) in enumerate(zip(prompts, results)):
            print('--' * 40 + 'start of problem --> model response')
            print(f'Request index: {req_idx}')
            print(f'{"--"*20}\nPrompt:\n{prompt}\n')
            print(f'{"--"*20}\nResult:\n{result}\n')
        print(f'--- {self.generate_until=}')
        return results

def do_lm_eval(model_name: str, task: str, kwargs: dict) -> dict: 
    gc.collect()
    torch.cuda.empty_cache()
    
    model = LLM(model_name, trust_remote_code=True)
    lm_obj = MyLM(model=model)
    task_manager = tasks.TaskManager()
    results = lm_eval.evaluator.simple_evaluate(
        model=lm_obj,
        tasks=[task],
        task_manager=task_manager,
        write_out=False,
        limit=kwargs.get('limit', 1), # limit the number of examples, if <1 then interpreted as %, default None --> all task benchmark
        random_seed=None,
        numpy_random_seed=None,
        torch_random_seed=None,
        fewshot_random_seed=None
    )

    del model
    del lm_obj
    gc.collect()
    torch.cuda.empty_cache()
    return results

def main(**kwargs):
    gc.collect()
    torch.cuda.empty_cache()

    print(f'{"__"*32} Start of eval main: {main=}')
    print(f'{STOP_TOKENS=}')
    # task = kwargs.get('task', 'putnam_axiom_original')
    # task = kwargs.get('task', 'putnam_axiom_variations')
    task = kwargs.get('task', 'putnam_axiom_53')
    print(f'{task=}')
    directory_path = expanduser("~/data/runs_logic_cont/run_2024_m11_d20_t21h_22m_22s") # putnam-axiom 53 train 20 epochs
    directory_path = expanduser("~/data/runs_logic_cont/run_2024_m11_d23_t21h_55m_07s") # putnam-axiom 53 train 100 epochs
    directory_path = expanduser("~/data/runs_logic_cont/run_2024_m12_d05_t12h_07m_56s")
    # directory_path = expanduser(kwargs.get('model_name_or_path', 'google/gemma-2-2b'))
    print(f'{directory_path=}')
    # Extract and sort the model_list by checkpoint index
    model_list = sorted(
        [
            os.path.join(directory_path, entry)
            for entry in os.listdir(directory_path)
            if os.path.isdir(os.path.join(directory_path, entry)) and entry.startswith("checkpoint-")
        ],
        key=lambda x: int(x.split('/')[-1].split('-')[-1])  # Extract the checkpoint index for sorting
    )
    print(f'{model_list}')
    print("\n".join(model_list))
    print(f'{len(model_list)=} (should be same as the expect steps/epochs +1 roughly)')
    # since we might be running multiple evals at once, the next time we run an eval we add it to the config, since wandb doesn't let you update the config if a key already has a value
    if 'task' in wandb.config:
        next_task: list = wandb.config['task'] + [task]  # note: list.append() returns None, so build a new list
        wandb.config.update({"model_list": model_list, 'task': next_task}, allow_val_change=True)
    else:
        wandb.config.update({"model_list": model_list, 'task': [task]}, allow_val_change=True)
    print(f'{wandb.config=}')
    # - Start eval run
    wandb.run.define_metric(f"{task}/eval_bench/accuracy", 
                            step_metric=f"{task}/eval_bench/checkpoint_idx")
    wandb.run.define_metric(f"{task}/eval_bench/checkpoint_idx") 
    model_2_accuracy = {model_name: None for model_name in model_list}
    print(f'{model_2_accuracy}')
    accs: list[float] = []
    for model_name in model_list:
        torch.cuda.empty_cache()
        gc.collect()

        print(f"{'=' * 100}")
        print(f'running model with name: {model_name}')
        print(f"{'=' * 100}")

        # model = LLM(model_name, trust_remote_code=True)
        # lm_obj = MyLM(model=model)
        # task_manager = tasks.TaskManager()
        # results = lm_eval.evaluator.simple_evaluate(
        #     model=lm_obj,
        #     tasks=[task],
        #     task_manager=task_manager,
        #     write_out=False,
        #     limit=kwargs.get('limit', 1), # limit the number of examples, if <1 then interpreted as %, default None --> all task benchmark
        #     random_seed=None,
        #     numpy_random_seed=None,
        #     torch_random_seed=None,
        #     fewshot_random_seed=None
        # )
        results = do_lm_eval(model_name, task, kwargs)
        print(f"arguments: {results['samples'][task][0]['arguments'][0][0]=}")
        print(f"resps: {results['samples'][task][0]['resps'][0][0]=}")
        print(f'{results.keys()=}')
        print(f'{results["results"][task].keys()=}')

        print("results:", results["results"][task].get("exact_match,none", None))
        accuracy = results["results"][task].get("exact_match,none", None) 
        model_2_accuracy[model_name] = accuracy
        accs.append(float(accuracy))
        checkpoint_idx = int(model_name.split('/')[-1].split('-')[-1])  # e.g., 'checkpoint-70'
        
        # Log accuracy for each checkpoint
        print(f'Checkpoint idx: {checkpoint_idx=}')
        print(f"Accuracy for that checkpoint: {accuracy=} ({checkpoint_idx=})")
        print(f"Checkpoint full name: {model_name=}")
        wandb.log({f"{task}/eval_bench/checkpoint_idx": checkpoint_idx, 
                   f"{task}/eval_bench/accuracy": accuracy})
        # del model
        # del lm_obj
        gc.collect()
        torch.cuda.empty_cache()
        # gc.collect()
        # torch.cuda.empty_cache()
    print(f'{model_2_accuracy=}\n{accs=}')
    print(f"{'=' * 100}")
    return {'model_2_accuracy': model_2_accuracy, 'accs': accs}

def _main(**kwargs):
    today = datetime.now().strftime('%Y_m%m_d%d_t%Hh_%Mm_%Ss')
    run_name = f'{today}'
    run = wandb.init(
        # mode=kwargs.get('mode', 'dryrun'), 
        mode=kwargs.get('mode', 'online'), 
        project="putnam-axiom", 
        name=run_name, 
        save_code=True, 
        config=kwargs | {'hostname': gethostname()}
    )
    kwargs = kwargs | {'today': today}
    os.environ['CUDA_VISIBLE_DEVICES'] = str(6)
    print("--> WARNING/REMINDER: CUDA device hardcoded in script!\n"*10)
    main(**kwargs)
    print(f'{run.get_url()=}')
    wandb.finish()

if __name__ == "__main__":
    import time
    start_time = time.time()
    fire.Fire(_main)
    elapsed_time = time.time() - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds, or {elapsed_time / 60:.2f} minutes, or {elapsed_time / 3600:.2f} hours.\a")

Note: I’ve had to use the same trick for training scripts as well. Full script:

from datetime import datetime
from typing import Optional
import random
import torch
from transformers import PushToHubCallback
from transformers import get_cosine_schedule_with_warmup
from trl import SFTConfig, SFTTrainer
import os
import fire
import wandb
import sys

from train.callbacks import GenCallbackWithHFGenerate
from train.data import load_math_style_dataset, print_first_example_after_decode
import train.models

from train.utils import seed_everything

def main(**config):
    # -- Seed everything
    seed_everything(seed=config.get('seed', 0))
    
    # -- HF login
    from huggingface_hub import login
    token = open(os.path.expanduser("~/keys/master_hf_token.txt")).read().strip()
    login(token=token)

    # -- Get model
    model, tok = train.models.load_mdl_and_tok(config.get('pretrained_model_name_or_path', 'google/gemma-2-2b')) 
    # model, tok = train.models.load_mdl_and_tok(config.get('pretrained_model_name_or_path', 'meta-llama/Llama-3.1-8B')) 

    # -- Load datasets
    ds_name_or_path = config.get('ds_name_or_path', 'Putnam-AXIOM/putnam-axiom-dataset')
    train_split, val_split = config.get('train_split', 'func_original_53_10_30_2024'), config.get('val_split', 'func_variations_265_11_23_2024')
    print(f'\n---> {ds_name_or_path=} {train_split=} {val_split=}\n')
    train_dataset = load_math_style_dataset(ds_name_or_path, tok, config.get('max_seq_length', 512), end=1, split=train_split)
    print_first_example_after_decode(train_dataset, tok)
    # eval_dataset = load_math_style_dataset(ds_name_or_path, tok, config.get('max_seq_length', 512), end=15, split=val_split)
    eval_dataset = train_dataset
    print(f'{len(train_dataset)=}\n{len(eval_dataset)=}')
    wandb.config.update({'dataset': f'{ds_name_or_path} ({train_split=} {val_split=})'})

    # -- Prepare output directory
    today: str = datetime.now().strftime('%Y_m%m_d%d_t%Hh_%Mm_%Ss')
    output_dir: str = os.path.expanduser(f"~/data/runs_logic_cont/run_{config.get('today', today)}")
    print(f'{output_dir=}')
    
    # Save the initial model and tokenizer as checkpoint-0
    initial_checkpoint_dir = os.path.join(output_dir, "checkpoint-0")
    os.makedirs(initial_checkpoint_dir, exist_ok=True)
    print(f"Saving initial checkpoint and tokenizer at {initial_checkpoint_dir}")
    model.save_pretrained(initial_checkpoint_dir)
    tok.save_pretrained(initial_checkpoint_dir)

    # -- Train model
    # max_steps = 50  # Limit fine-tuning to a few steps
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(random.randint(0, 7))
    # config = {'max_steps': 2, 'eval_steps': 1, 'logging_steps': 1, 
    #           'save_strategy': 'steps', 'save_steps': 1, 'eval_strategy': 'steps'}
    # config = config | {'CUDA_VISIBLE_DEVICES': os.environ.get('CUDA_VISIBLE_DEVICES', 'maybe 0')}
    training_args = SFTConfig(
        max_steps=config.get('max_steps', 30),
        # --
        output_dir=output_dir,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        # -- logging opts
        save_steps=config.get('save_steps', 5), 
        save_strategy=config.get('save_strategy', 'steps'),
        eval_on_start=config.get('eval_on_start', True),
        evaluation_strategy=config.get('eval_strategy', 'steps'), 
        eval_steps=config.get('eval_steps', 1), 
        logging_first_step=config.get('logging_first_step', True), # Default to False, unsure 100% what this does but looks like a good idea
        logging_strategy=config.get('logging_strategy', 'steps'),
        logging_steps=config.get('logging_steps', 1),
        # --
        num_train_epochs=config.get('num_train_epochs', 10),
        max_seq_length=config.get('max_seq_length', 512),
        per_device_train_batch_size=config.get('batch_size', 2),
        gradient_accumulation_steps=config.get('gradient_accumulation_steps', 2),
    )
    # Calculate Total Steps
    steps_per_epoch = (len(train_dataset) // training_args.per_device_train_batch_size) // training_args.gradient_accumulation_steps
    total_steps = steps_per_epoch * training_args.num_train_epochs
    print(f'{steps_per_epoch=}')

    # Optimizer and Scheduler
    # optimizer_grouped_parameters = [{'params': [p for p in model.parameters()], 'weight_decay': 1e-4}]
    optimizer_grouped_parameters = [{'params': [p for p in model.parameters()], 'weight_decay': 0}]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.get('learning_rate', 1e-5))

    # Add Cosine Learning Rate Scheduler
    # warmup_steps = int(0.01 * total_steps)  # Warm-up for 1% of total steps
    warmup_steps = 0
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
    scheduler = None
    print(f'{total_steps=} {warmup_steps=}')
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        optimizers=(optimizer, scheduler),
        callbacks=[GenCallbackWithHFGenerate(model, tok)]
    )
    print(f"\nStarting fine-tuning...")
    trainer.train()
    # - end run
    return os.path.expanduser(output_dir)

def run_eval_logic_contamination(output_dir: str):
    """
    Runs the eval_logic_contamination.py script with the specified output directory.

    Args:
        output_dir (str): The directory where the model is saved, expanded using `os.path.expanduser`.
    """
    import gc
    torch.cuda.empty_cache()
    gc.collect()
    output_dir = os.path.expanduser(output_dir)  # Ensure `output_dir` is expanded 
    from eval_logic_contamination import main
    task='putnam_axiom_53'
    res: dict = main(model_name_or_path=output_dir, task=task)
    print(f'Results for {task=}: {res}')
    print(res)
    # task='putnam_axiom_53' # for debugging
    task='putnam_axiom_variations'
    res: dict = main(model_name_or_path=output_dir, task=task)
    print(f'Results for {task=}: {res}')
    print(res)
    # wandb.run.define_metric("eval/accuracy", step_metric="eval/checkpoint_idx")
    # wandb.run.define_metric("eval/checkpoint_idx") 
    # for idx, acc in [(10,5), (20,10), (30,15)]:
    #     wandb.log({'eval/accuracy': acc, 'eval/checkpoint_idx': idx})

def _main(**kwargs):
    from datetime import datetime
    today = datetime.now().strftime('%Y_m%m_d%d_t%Hh_%Mm_%Ss') # eg '2024_m01_d22_t13h_00m_30s'
    run_name = f'{today}' 
    kwargs = kwargs | {'today': today}
    # run = wandb.init(mode=kwargs.get('mode', 'dryrun'), project="putnam-axiom", name=run_name, save_code=True, config=kwargs)
    run = wandb.init(mode=kwargs.get('mode', 'online'), project="putnam-axiom", name=run_name, save_code=True, config=kwargs)
    # wandb.run.log_code(f"./{os.path.basename(__file__)}") # maybe logscode immediately # ref: https://stackoverflow.com/questions/79256112/how-to-log-only-the-current-script-file-to-wb-code-panel-immediately
    # wandb.config.update()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(6)
    print("--> WARNING/REMINDER: CUDA device hardcoded in script!\n"*10)
    output_dir = main(**kwargs)
    run_eval_logic_contamination(output_dir)
    # from train.utils import copy_to_dfs
    # copy_to_dfs(output_dir)
    run.alert(title="Run Completed", text=f"Run finished, run url: {run.get_url()}")
    print(f'{run.get_url()=}')
    wandb.finish()

if __name__ == "__main__":
    import time
    start_time = time.time()
    fire.Fire(_main)
    elapsed_time = time.time() - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds, or {elapsed_time / 60:.2f} minutes, or {elapsed_time / 3600:.2f} hours.\a")

Cross-posted: Why does moving ML model initialization into a function prevent GPU OOM errors when del, gc.collect(), and torch.cuda.empty_cache() fail? - Stack Overflow
