Error occurs when loading additional parameters in multi-gpu training

I’m training plugins (e.g. adapters) on top of a language model on multiple GPUs using Hugging Face Accelerate. However, strange things happen when I try to load additional parameters into the model: training does not progress, and the process state is not what I expect.

My launch command looks like this:

    CUDA_VISIBLE_DEVICES=1,2,3,4 python -m torch.distributed.launch --nproc_per_node 4 --use_env ./
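For reference, with this command each of the four processes sees physical GPUs 1, 2, 3, 4 renumbered as cuda:0, cuda:1, cuda:2, cuda:3, and (because of --use_env) reads its own index from the LOCAL_RANK environment variable. The small sketch below illustrates that mapping; the helper names are hypothetical, and it assumes CUDA_VISIBLE_DEVICES holds plain integer indices rather than GPU UUIDs.

```python
import os


def local_rank_from_env(default: int = 0) -> int:
    # With `--use_env`, torch.distributed.launch exports each process's local
    # rank as the LOCAL_RANK environment variable (instead of --local_rank).
    return int(os.environ.get("LOCAL_RANK", default))


def physical_gpu(local_rank: int, visible_devices: str) -> int:
    # CUDA_VISIBLE_DEVICES renumbers the listed GPUs from 0, so the logical
    # device cuda:<local_rank> is the local_rank-th entry of the list.
    # Assumes integer indices in the variable, not GPU UUIDs.
    return int(visible_devices.split(",")[local_rank])


visible = "1,2,3,4"  # the CUDA_VISIBLE_DEVICES value from the command above
for rank in range(4):
    print(f"local rank {rank}: cuda:{rank} -> physical GPU {physical_gpu(rank, visible)}")
```

Note that cuda:0 means the same physical GPU (GPU1) in every process unless a process explicitly uses its own local rank.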

But the process state looks wrong: there are three processes on GPU1, and I don’t understand why. This is definitely not correct.

The code works fine on a single GPU, or on multiple GPUs when I don’t load the additional parameters, so I’m fairly sure the bug is in the parameter-loading code.

FYI, the code related to loading additional parameters is as follows:

    model = MyRobertaModel.from_pretrained(...)  # arguments omitted

    accelerator.wait_for_everyone()  # I tried adding a barrier, but it doesn't solve the problem
    t = args.t
    if t > 0:
        embed_pool = torch.load(os.path.join(args.saved_plugin_dir, 'embed_pool.pth'))
        for i in range(t):
            model.add_prompt_embedding(t=i, saved_embedding=embed_pool[i])
        plugin_ckpt = torch.load(os.path.join(args.saved_plugin_dir, 'plugin_ckpt.pth'))



The code for loading the plugins looks like this:

    def load_plugin(self, plugin_ckpt):
        idx = 0
        for name, sub_module in super().named_modules():
            if isinstance(sub_module, MyAdapter):
                idx += 1

        print('Load plugins successfully!')

Also, my environment is:
Python 3.6.8
transformers 4.11.3
accelerate 0.5.1
NVIDIA GPU cluster

Thanks a lot for your help!

It turned out to be a silly mistake on my part.

    embed_pool = torch.load(os.path.join(args.saved_plugin_dir, 'embed_pool.pth'))

should be changed to:

    embed_pool = torch.load(os.path.join(args.saved_plugin_dir, 'embed_pool.pth'), map_location=torch.device('cpu'))
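A minimal, CPU-only sketch of the fix: it saves a stand-in embedding pool (the list of three 4x8 tensors is made up for illustration) and reloads it with map_location, so every rank gets CPU tensors regardless of which device they were saved from. The helper name `load_to_cpu` is hypothetical.

```python
import os
import tempfile

import torch


def load_to_cpu(path):
    # map_location forces every storage to be deserialized onto the CPU,
    # so no rank touches a GPU while reading the checkpoint.
    return torch.load(path, map_location=torch.device("cpu"))


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "embed_pool.pth")
    # Stand-in for the real embed_pool.pth: a list of embedding tensors.
    torch.save([torch.randn(4, 8) for _ in range(3)], path)
    embed_pool = load_to_cpu(path)

print([t.device.type for t in embed_pool])
```

Loading onto the CPU is the most conservative choice; the tensors can be moved to the right device later, for example when the model is moved to its device or prepared by Accelerate.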

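By default, torch.load() restores tensors to the device they were saved from, which here is device:0 (cuda:0). Every process sees the same cuda:0: with CUDA_VISIBLE_DEVICES=1,2,3,4 it is physical GPU1 for all four ranks. So each rank also created a CUDA context on that one GPU, which is what showed up as the extra processes there.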
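If you would rather have each rank load straight onto its own GPU instead of the CPU, map_location also accepts a dict that remaps storage locations. The sketch below is hypothetical (the helper name is made up); it assumes the checkpoint was saved from cuda:0 and that LOCAL_RANK is set by `--use_env`. It sends cuda:0 tensors to the process's own GPU and falls back to the CPU when no GPU is present.

```python
import os

import torch


def per_rank_map_location(saved_device: str = "cuda:0"):
    # Hypothetical helper. LOCAL_RANK is set per process by
    # `torch.distributed.launch --use_env`; default to 0 when run standalone.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if torch.cuda.is_available():
        # Remap storages saved on `saved_device` to this rank's own GPU.
        # Locations not listed in the dict are left unchanged.
        return {saved_device: f"cuda:{local_rank}"}
    # No GPU available: load everything onto the CPU instead.
    return torch.device("cpu")


# Usage inside the training script (path as in the question):
# embed_pool = torch.load(path, map_location=per_rank_map_location())
```

Only the keys listed in the dict are remapped, so this variant relies on knowing where the checkpoint was saved; mapping to the CPU avoids that assumption.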

Marking this as solved.