Initializing a model with empty weights causes OOM when offloading to disk

Reproduction

I am trying to run DeepSeek-R1 in bf16 on 4 A100 GPUs, with offloading to CPU (400 GB of memory) and disk (5 TB). I am running this job under the Slurm workload manager.
This is how I am loading the model:

    import json
    import os

    import torch
    from accelerate import Accelerator, init_empty_weights, load_checkpoint_and_dispatch
    from model import ModelArgs, Transformer  # model definitions from the DeepSeek-V3 inference code

    accelerator = Accelerator()
    torch.set_default_dtype(torch.bfloat16)
    torch.manual_seed(42)

    # `config` and `ckpt_path` come from the script's command-line arguments
    with open(config) as f:
        args = ModelArgs(**json.load(f))

    with init_empty_weights():
        model = Transformer(args)

    model = load_checkpoint_and_dispatch(
        model,
        os.path.join(ckpt_path, "model0-mp1.safetensors"),  # using convert.py from DeepSeek-V3, combined all 163 safetensors into one
        device_map="auto",
        offload_folder="/scratch/rr4549/offload",
        offload_buffers=True,
        offload_state_dict=True,
        max_memory={0: "70GB", 1: "70GB", 2: "70GB", 3: "70GB", "cpu": "300GB"},
        dtype=torch.bfloat16
    )
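
For reference, here is a minimal sketch (not part of my script) of how the placement that device_map="auto" would produce can be previewed with Accelerate's infer_auto_device_map, assuming the same ModelArgs/Transformer classes and the same memory budget as above:

    # Hypothetical sketch: preview which modules would land on GPUs 0-3,
    # "cpu", or "disk" before actually dispatching any weights.
    from accelerate import infer_auto_device_map, init_empty_weights

    with init_empty_weights():
        preview_model = Transformer(args)

    device_map = infer_auto_device_map(
        preview_model,
        max_memory={0: "70GB", 1: "70GB", 2: "70GB", 3: "70GB", "cpu": "300GB"},
        dtype=torch.bfloat16,
    )
    print(device_map)  # maps module names to 0-3, "cpu", or "disk"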

I run the script with:

    accelerate launch --multi_gpu --num_processes 4 --num_machines 1 generate2.py --ckpt-path /scratch/rr4549/DeepSeek-R1-Demo/ --config configs/config_671B.json --interactive --temperature 0.7 --max-new-tokens 1
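
For context, --num_processes 4 means four copies of generate2.py are launched, each running the loading code above; a tiny hypothetical print at the top of the script (not in my code) makes that visible:

    # Hypothetical debug print: each launched process reports the rank
    # variables set by the launcher.
    import os
    print(
        "RANK:", os.environ.get("RANK"),
        "LOCAL_RANK:", os.environ.get("LOCAL_RANK"),
        "WORLD_SIZE:", os.environ.get("WORLD_SIZE"),
    )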

The script crashes while initializing the model with empty weights.
Error:

W0131 17:51:41.347000 3004393 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3004406 closing signal SIGTERM
W0131 17:51:41.352000 3004393 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3004407 closing signal SIGTERM
W0131 17:51:41.354000 3004393 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3004408 closing signal SIGTERM
E0131 17:51:42.033000 3004393 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -9) local_rank: 0 (pid: 3004405) of binary: /ext3/miniforge3/bin/python3.12
Traceback (most recent call last):
  File "/ext3/miniforge3/bin/accelerate", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/ext3/miniforge3/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/ext3/miniforge3/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1163, in launch_command
    multi_gpu_launcher(args)
  File "/ext3/miniforge3/lib/python3.12/site-packages/accelerate/commands/launch.py", line 792, in multi_gpu_launcher
    distrib_run.run(args)
  File "/ext3/miniforge3/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/ext3/miniforge3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniforge3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
generate2.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-01-31_17:51:41
  host      : ***
  rank      : 0 (local_rank: 0)
  exitcode  : -9 (pid: 3004405)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 3004405
========================================================
slurmstepd: error: Detected 1 oom_kill event in StepId=56705791.batch. Some of the step tasks have been OOM Killed.
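
As a rough check on the host-memory pressure, a hypothetical probe like the following (psutil is an extra dependency, not something generate2.py imports), placed just before load_checkpoint_and_dispatch, would show how much resident memory each rank is using at the step that gets killed:

    # Hypothetical probe: report each process's resident host memory right
    # before load_checkpoint_and_dispatch to see which step exhausts RAM.
    import os
    import psutil  # extra dependency, assumed available

    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024**3
    print(f"rank {os.environ.get('RANK', '?')}: RSS = {rss_gb:.1f} GiB before dispatch")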

Expected behavior

The model weights should be offloaded to the CPU and disk so that model inference can run.
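
That is, after load_checkpoint_and_dispatch returns, a forward pass along these lines should work, with the dispatch hooks paging offloaded weights in from CPU and disk as needed (the dummy tokens tensor and the Transformer call signature here are my assumptions):

    # Hedged sketch of the intended end state; the input is a placeholder.
    tokens = torch.randint(0, args.vocab_size, (1, 8), device="cuda:0")
    with torch.no_grad():
        logits = model(tokens)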
