Accelerate Distributed Randomly Hangs

Hi, I have a script that runs with the DataParallel trainer on a machine with 8 H100 GPUs (AWS p5 VM). When we run the script, it randomly gets stuck forever at some iteration relatively late in the process (between the 2000th and 4000th iteration).
We start the script with the following command:

accelerate launch src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml

While the script is stuck, the GPUs are at only 30% memory occupancy and utilization is at 0%.
The stack traces of the relevant processes look as follows:

 pgrep -P $(pgrep -o accelerate) | xargs -I {} py-spy dump --pid {}
Process 39: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 39 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 930 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4067 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4069 (idle)
Thread 4070 (idle)
Thread 4071 (idle)
Thread 4072 (idle)
Thread 4073 (idle)
Thread 4074 (idle)
Thread 4075 (idle)
Process 40: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 40 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 924 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4040 (idle)
Thread 4044 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4047 (idle)
Thread 4049 (idle)
Thread 4051 (idle)
Thread 4053 (idle)
Thread 4055 (idle)
Thread 4057 (idle)
Process 41: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 41 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 929 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4043 (idle)
Thread 4048 (idle)
Thread 4050 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4052 (idle)
Thread 4054 (idle)
Thread 4056 (idle)
Thread 4058 (idle)
Thread 4059 (idle)
Process 42: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 42 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 928 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4076 (idle)
Thread 4077 (idle)
Thread 4078 (idle)
Thread 4079 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4080 (idle)
Thread 4081 (idle)
Thread 4082 (idle)
Thread 4083 (idle)
Process 43: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 43 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 925 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4036 (idle)
Thread 4037 (idle)
Thread 4038 (idle)
Thread 4039 (idle)
Thread 4041 (active)
    partition_grads (deepspeed/runtime/zero/stage3.py:1467)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __reduce_and_partition_ipg_grads (deepspeed/runtime/zero/stage3.py:1276)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    reduce_independent_p_g_buckets_and_remove_grads (deepspeed/runtime/zero/stage3.py:1224)
    reduce_ready_partitions_and_remove_grads (deepspeed/runtime/zero/stage3.py:1483)
    reduce_partition_and_remove_grads (deepspeed/runtime/zero/stage3.py:1132)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4042 (idle)
Thread 4045 (idle)
Thread 4046 (idle)
Process 44: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 44 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 927 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4084 (idle)
Thread 4085 (idle)
Thread 4087 (idle)
Thread 4089 (idle)
Thread 4091 (idle)
Thread 4093 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4095 (idle)
Thread 4097 (idle)
Process 45: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 45 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 926 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4086 (idle)
Thread 4088 (idle)
Thread 4090 (idle)
Thread 4092 (idle)
Thread 4094 (idle)
Thread 4096 (idle)
Thread 4098 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
Thread 4099 (idle)
Process 46: /usr/bin/python3.10 -u src/compactifai_back/healing/scripts/fine_tune_accelerate.py --config_file src/compactifai_back/healing/configs/mixtral_8x7b/config.yaml
Python v3.10.12 (/usr/bin/python3.10)

Thread 46 (idle): "MainThread"
    backward (torch/autograd/__init__.py:266)
    backward (torch/_tensor.py:522)
    backward (deepspeed/runtime/fp16/loss_scaler.py:63)
    backward (deepspeed/runtime/zero/stage3.py:2213)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/engine.py:1976)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (accelerate/utils/deepspeed.py:166)
    backward (accelerate/accelerator.py:2126)
    training_loop (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:410)
    training_function (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:540)
    main (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:583)
    <module> (src/compactifai_back/healing/scripts/fine_tune_accelerate.py:587)
Thread 931 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 4060 (idle)
Thread 4061 (idle)
Thread 4062 (idle)
Thread 4063 (idle)
Thread 4064 (idle)
Thread 4065 (idle)
Thread 4066 (idle)
Thread 4068 (active)
    all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)
    wrapper (torch/distributed/c10d_logger.py:72)
    all_gather_into_tensor (deepspeed/comm/torch.py:219)
    _fn (torch/_dynamo/eval_frame.py:489)
    all_gather_into_tensor (deepspeed/comm/comm.py:305)
    log_wrapper (deepspeed/comm/comm.py:117)
    allgather_fn (deepspeed/comm/comm.py:320)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _dist_allgather_fn (deepspeed/runtime/zero/partition_parameters.py:93)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1217)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:463)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:434)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:385)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _fn (torch/_dynamo/eval_frame.py:489)
    pre_sub_module_backward_function (deepspeed/runtime/zero/parameter_offload.py:474)
    decorate_context (torch/utils/_contextlib.py:115)
    _run_before_backward_function (deepspeed/runtime/zero/parameter_offload.py:339)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    backward (deepspeed/runtime/zero/parameter_offload.py:358)
    apply (torch/autograd/function.py:289)
    backward (torch/autograd/__init__.py:266)
    backward (torch/utils/checkpoint.py:320)
    apply (torch/autograd/function.py:289)
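
Since the dump is long, here is a minimal sketch of a helper for summarizing it: it reads `py-spy dump` text and reports, per process, the innermost frame of the first active thread, then groups PIDs by that frame. The function names `innermost_active_frames` and `group_by_frame` are made up for illustration, and the parsing assumes the output layout shown above (py-spy lists the most recent call first).

```python
import re
from collections import defaultdict

# Patterns for the py-spy dump layout shown above (an assumption about the format).
PROCESS_RE = re.compile(r"^Process (\d+):")
THREAD_RE = re.compile(r"^Thread (\d+) \((\w+)\)")
FRAME_RE = re.compile(r"^\s+(\S+) \((.+):(\d+)\)$")


def innermost_active_frames(dump_text):
    """Map each PID to the innermost frame of its first active thread."""
    frames = {}
    pid = None
    want_frame = False
    for line in dump_text.splitlines():
        m = PROCESS_RE.match(line)
        if m:
            pid = int(m.group(1))
            want_frame = False
            continue
        m = THREAD_RE.match(line)
        if m:
            # Only the first active thread of each process is recorded.
            want_frame = m.group(2) == "active" and pid not in frames
            continue
        m = FRAME_RE.match(line)
        if m and want_frame and pid is not None:
            # The first frame after the thread header is the most recent call.
            frames[pid] = f"{m.group(1)} ({m.group(2)}:{m.group(3)})"
            want_frame = False
    return frames


def group_by_frame(frames):
    """Group PIDs by the frame they are currently sitting in."""
    groups = defaultdict(list)
    for pid, frame in sorted(frames.items()):
        groups[frame].append(pid)
    return dict(groups)
```

Reading the dump above this way, processes 39, 40, 41, 42, 44, 45 and 46 are all in `all_gather_into_tensor (torch/distributed/distributed_c10d.py:2709)`, while process 43 is in `partition_grads (deepspeed/runtime/zero/stage3.py:1467)`.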