OOM when using torch.nn.parallel.DistributedDataParallel to train LLaMA-7B

I intended to use 4 NVIDIA 3090 GPUs to train LLaMA-7B (float16) with DistributedDataParallel, but an OOM occurred. The LLaMA-7B weights occupy around 15 GB of each GPU's 24 GB, but as soon as I call torch.nn.parallel.DistributedDataParallel(), the error happens.

-----the code is like below---------
model.to(device)  # running OK here, GPU memory is 15G/24G
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)  # OOM happens here

I would like to know why the OOM happens only after calling the DDP constructor, and why memory is allocated again at that point. Does anyone know? I would be very grateful for any help.
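For completeness, here is a minimal, self-contained version of the wrapping step I am doing. It uses a single-process gloo group on CPU and a tiny stand-in model just so the snippet runs anywhere; my real run uses NCCL across 4 GPUs with the LLaMA-7B model. The gradient_as_bucket_view=True flag is something I have seen suggested for reducing DDP's extra gradient-bucket allocation, though I have not confirmed it is enough in my case:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so the example is runnable on CPU;
# the real setup uses backend="nccl" with one process per GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)  # stand-in for the LLaMA-7B model

# The wrap that OOMs in my real run. DDP allocates flat "bucket"
# buffers for gradient all-reduce; gradient_as_bucket_view=True makes
# the .grad tensors views into those buckets instead of separate copies.
ddp_model = DDP(model,
                find_unused_parameters=True,
                gradient_as_bucket_view=True)

out = ddp_model(torch.randn(4, 16))
out.sum().backward()

dist.destroy_process_group()
```

On GPU the same call would also pass device_ids=[args.local_rank] and output_device=args.local_rank, as in my snippet above.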