I am using accelerate to perform multi-GPU inference with the OpenLLaMA models (3B/13B). Both models run inference on a single GPU perfectly fine, even with a large batch size of 32. Since I have more than one GPU in my machine, I want to do data-parallel inference, so I tried both torch DDP and Hugging Face accelerate. Both crash with an OOM error for the 13B model, and for the 3B model they use about 3X the memory compared to running on a single GPU. I understand that DDP has some overhead (although I am only doing inference, no training), but 3X still seems odd. I have attached minimal code to reproduce the issue; the first few lines give the exact commands to run and the packages to install in the virtual environment. I have reproduced this on 3 separate machines. Why is this happening, and how can I resolve it? Thank you so much, your help is very appreciated!
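For reference, my back-of-the-envelope expectation: 13B parameters × 2 bytes (fp16) is roughly 26 GB of weights, which lines up with the ~25 GB peak I see on a single GPU. So I expected each process in the multi-GPU runs to need about that much on its own GPU, not to OOM or to use 3X.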
'''
Commands to run:
1. python mini_example.py --use_single_gpu
2. accelerate launch --multi_gpu --gpu_ids=0,1,2,3 --mixed_precision=fp16 --num_processes=2 --num_machines=1 --main_process_port=29500 mini_example.py --use_accelerate
3. python mini_example.py --use_ddp
Packages used:
pip install transformers
pip install torch
pip install accelerate
pip install sentencepiece
pip install protobuf
'''
import torch
import warnings
warnings.filterwarnings("ignore")
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
model_path = 'openlm-research/open_llama_13b'
import subprocess
def get_gpu_memory_usage():
    try:
        result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'])
        return list(map(int, result.decode('utf-8').strip().split('\n')))
    except Exception as e:
        print(f"Error in fetching GPU status: {e}")
        return []
## This works with a peak memory usage of 25 GB on a single GPU
def single_gpu(args):
    args.device = "cuda:0"
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=args.device)
    prompt = 'Q: What is the largest animal?\nA:'
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(args.device)
    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")
    generation_output = model.generate(
        input_ids=input_ids, max_new_tokens=101
    )
    print(tokenizer.decode(generation_output[0]))
## This crashes with OOM error
def multi_gpu_accelerate_load_to_memory():
    from accelerate import Accelerator
    accelerator = Accelerator()
    # from_pretrained (no device_map) first materializes the fp16 weights on CPU,
    # then prepare() moves/wraps the model for this process's GPU
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
    model = accelerator.prepare(model)
    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def worker_ddp(rank, args):
    args.device = "cuda:" + str(args.device_ids[rank])
    # tokenizer = LlamaTokenizer.from_pretrained(model_path)
    # device_map already places the fp16 weights on this rank's GPU
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=args.device)
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:5438', world_size=args.num_devices, rank=rank)
    model = model.to(args.device)
    torch.cuda.set_device(args.device)
    model = DDP(model, device_ids=[args.device_ids[rank]])
    memory_usages = get_gpu_memory_usage()
    print(f"Memory usages: {memory_usages}")
## crashes with OOM error
def multi_gpu_ddp(args):
    args.device_ids = list(map(int, args.gpu_ids))
    args.num_devices = len(args.device_ids)
    import torch.multiprocessing as mp
    mp.set_start_method("spawn", force=True)
    os.environ["NCCL_P2P_DISABLE"] = "1"
    os.environ["NCCL_IB_DISABLE"] = "1"
    import time
    time.sleep(5)
    print("LAUNCHING DDP ON", args.num_devices, "GPUs: ", args.device_ids)
    mp.spawn(worker_ddp, nprocs=args.num_devices, args=(args,))
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_single_gpu", action="store_true")
    parser.add_argument("--use_accelerate", action="store_true")
    parser.add_argument("--use_ddp", action="store_true")
    parser.add_argument('--gpu_ids', type=str, nargs="+", default=["0", "1", "2", "3"])
    args = parser.parse_args()
    if args.use_single_gpu:
        single_gpu(args)  ## Works perfectly fine
    elif args.use_accelerate:
        multi_gpu_accelerate_load_to_memory()  ## Crashes with OOM error
    elif args.use_ddp:
        multi_gpu_ddp(args)  ## Crashes with OOM error
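For completeness, this is the kind of plain data-parallel generation I am ultimately trying to get to: each process loads its own fp16 copy directly onto its assigned GPU and generates for its own slice of prompts, with no gradient sync. This is only an untested sketch and not part of the repro; it reuses the imports and model_path from the script above, and args.prompts is a hypothetical list of prompts I would pass in:

def worker_inference_only(rank, args):
    ## Untested sketch: data-parallel generation without any DDP wrapping.
    ## Reuses torch, LlamaTokenizer, LlamaForCausalLM and model_path from above.
    device = "cuda:" + str(args.device_ids[rank])
    torch.cuda.set_device(device)
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=device)
    # each rank handles every num_devices-th prompt; args.prompts is hypothetical
    for prompt in args.prompts[rank::args.num_devices]:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        output = model.generate(input_ids=input_ids, max_new_tokens=101)
        print(f"[rank {rank}] {tokenizer.decode(output[0])}")

I would launch this the same way as worker_ddp above, via mp.spawn with one process per GPU.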