Cannot run multi-GPU training on SLURM

Hi,
I'm currently trying to set up multi-GPU training with Accelerate for GRPO training from the TRL library. Single-GPU training works, but as soon as I switch to multi-GPU, everything fails and I can't figure out why.
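
For context, test.py is basically the stock TRL GRPO setup. A minimal sketch of what it does (the real script uses a different model, dataset and reward function, so treat this only as an outline):

# sketch of the training setup, not the actual test.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# toy dataset and reward function just to show the shape of the setup
dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # dummy reward: prefer completions around 20 characters
    return [-abs(20 - len(c)) for c in completions]

training_args = GRPOConfig(output_dir="grpo-test", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()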

The error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/mgroepl/hFace/test.py", line 92, in <module>
[rank1]:     trainer.train()
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/transformers/trainer.py", line 2241, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/transformers/trainer.py", line 2365, in _inner_training_loop
[rank1]:     model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
[rank1]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/accelerate/accelerator.py", line 1389, in prepare
[rank1]:     result = tuple(
[rank1]:              ^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/accelerate/accelerator.py", line 1390, in <genexpr>
[rank1]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank1]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/accelerate/accelerator.py", line 1263, in _prepare_one
[rank1]:     return self.prepare_model(obj, device_placement=device_placement)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/accelerate/accelerator.py", line 1522, in prepare_model
[rank1]:     model = torch.nn.parallel.DistributedDataParallel(
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 825, in __init__
[rank1]:     _verify_param_shape_across_processes(self.process_group, parameters)
[rank1]:   File "/itet-stor/mgroepl/net_scratch/conda_envs/hFace/lib/python3.12/site-packages/torch/distributed/utils.py", line 288, in _verify_param_shape_across_processes
[rank1]:     return dist._verify_params_across_processes(process_group, tensors, logger)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: CUDA error: named symbol not found
[rank1]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

And my current batch file for launching the multi-GPU run:

#!/bin/bash

#SBATCH --output=/home/mgroepl/log/%j.out     # where to store the output (%j is the JOBID), subdirectory "log" must exist
#SBATCH --error=/home/mgroepl/log/%j.out   # where to store error messages
#SBATCH --nodes=1                   # number of nodes
#SBATCH --ntasks-per-node=1         # number of MP tasks
#SBATCH --gres=gpu:2  
#SBATCH --constraint=ampere
# Load Conda (Important for Non-Interactive Shells)

source /itet-stor/mgroepl/net_scratch/conda/etc/profile.d/conda.sh
conda init bash
conda activate hFace

export PYTHONPATH=$PYTHONPATH:/itet-stor/mgroepl/net_scratch/trl
export HF_HOME=/itet-stor/mgroepl/net_scratch/hCache
export GPUS_PER_NODE=2

head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
######################

    
srun accelerate launch \
    --num_processes 1 \
    --num_machines $SLURM_NNODES \
    --rdzv_backend c10d \
    --main_process_ip $head_node_ip \
    --main_process_port 29500 \
    /home/mgroepl/hFace/test.py



echo "Running on node: $(hostname)"
echo "In directory:    $(pwd)"
echo "Starting on:     $(date)"
echo "SLURM_JOB_ID:    ${SLURM_JOB_ID}"



# Send more noteworthy information to the output log
echo "Finished at:     $(date)"

# End the script with exit code 0
exit 0
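
Since the crash happens inside DDP's parameter verification, i.e. the very first NCCL/CUDA collective, I suspect the training code itself is not the culprit. Something like the sketch below (ddp_check.py is just a hypothetical name, launched with the same srun/accelerate command as test.py) is what I would use to rule that out:

# ddp_check.py -- hypothetical minimal NCCL sanity check, not my actual script
import os
import torch
import torch.distributed as dist

def main():
    # accelerate launch / torchrun set RANK, WORLD_SIZE and LOCAL_RANK per process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # a single all-reduce exercises the same collective path DDP uses when it
    # verifies parameter shapes across processes
    x = torch.ones(1, device=f"cuda:{local_rank}") * dist.get_rank()
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    print(f"rank {dist.get_rank()}: all_reduce result = {x.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()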

Would appreciate any help.


Seems like an ongoing issue…