Data-Parallel Multi-GPU Inference

I found the following statement in the Accelerator documentation for accelerate.Accelerator.prepare():
"You don’t need to prepare a model if it is used only for inference without any kind of mixed precision."

In data-parallel multi-GPU inference, we want a copy of the model to reside on each GPU. How can we achieve that without passing the model through prepare()?

You just move the model to the device. Check out the new distributed inference tutorial, and install accelerate from dev if you want to use the new split_between_processes API. Otherwise, pass your dataloader to Accelerator.prepare and do

Wrapping your model in DDP is only relevant when you need to synchronize gradients during training (that’s what it’s designed for), so for inference just load the model onto each device normally.
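The split_between_processes route can be pictured with a small pure-Python sketch. The helper split_for_process below is hypothetical: it is meant to mimic the contiguous chunking that accelerate's split_between_processes performs (earlier ranks take one extra item when the inputs don't divide evenly), but it is not accelerate's code, and it loops over ranks in a single process purely for illustration. The commented lines show roughly where the real PartialState API and a plain model.to(device) (no DDP wrapper) would go; run_model is a placeholder, not a library function.

```python
# Hypothetical helper: a pure-Python sketch of contiguous splitting of inputs
# across processes, modeled on PartialState.split_between_processes
# (not accelerate's actual implementation).
def split_for_process(inputs, process_index, num_processes):
    """Return the contiguous slice of `inputs` assigned to one process.

    The first `len(inputs) % num_processes` processes get one extra item,
    so slice lengths differ by at most one.
    """
    per_process, extras = divmod(len(inputs), num_processes)
    start = process_index * per_process + min(process_index, extras)
    end = start + per_process + (1 if process_index < extras else 0)
    return inputs[start:end]


prompts = ["a dog", "a cat", "a bird", "a fish", "a horse"]

# Simulate two processes in one loop. In a real launch, each process would
# instead do something like:
#     from accelerate import PartialState
#     state = PartialState()
#     model.to(state.device)  # plain copy of the model on this GPU, no DDP
#     with state.split_between_processes(prompts) as local_prompts:
#         outputs = run_model(model, local_prompts)  # run_model is hypothetical
for rank in range(2):
    print(rank, split_for_process(prompts, rank, 2))
```

Because each process only ever reads its own slice and never updates weights, no gradient synchronization (and hence no DDP wrapper) is involved.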

Thanks @muellerzr for your reply. Is there any benefit in using split_between_processes() over accelerate.Accelerator().prepare() on a dataloader?

It’s useful if you don’t want to make a DataLoader, or if you have inputs that can’t easily go into one (like the prompts in that example).
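One practical difference with the DataLoader route: when a DataLoader is passed to Accelerator.prepare, every process has to run the same number of steps, so a dataset that doesn't divide evenly gets padded with duplicated samples, and gather_for_metrics drops those duplicates when results are collected. The sketch below only illustrates that idea. shard_with_padding and gather_and_trim are hypothetical helpers, and the strided sharding with wrap-around padding is modeled on torch's DistributedSampler (shuffle=False), not on accelerate's exact batch sharding.

```python
# Simplified sketch (hypothetical helpers, not accelerate's implementation) of why
# a sharded DataLoader pads its data and why gathered results must be trimmed.
# Strided sharding with wrap-around padding, modeled on DistributedSampler.

def shard_with_padding(num_samples, num_processes):
    """Return per-rank index lists of equal length (num_samples must be > 0)."""
    total = -(-num_samples // num_processes) * num_processes  # round up to a multiple
    padding = total - num_samples
    indices = list(range(num_samples))
    # Pad by repeating samples from the start so every rank gets the same count.
    indices += (indices * (padding // num_samples + 1))[:padding]
    return [indices[rank::num_processes] for rank in range(num_processes)]


def gather_and_trim(per_rank_results, num_samples):
    """Interleave results step by step (as an all-gather would) and drop the padding."""
    gathered = []
    for step in zip(*per_rank_results):  # one step: each rank's result for that step
        gathered.extend(step)
    return gathered[:num_samples]


data = ["s0", "s1", "s2", "s3", "s4"]
shards = shard_with_padding(len(data), 2)
# Stand-in for per-process inference: each rank processes only its own indices.
results = [[data[i].upper() for i in shard] for shard in shards]
print(shards)
print(gather_and_trim(results, len(data)))
```

With split_between_processes you handle this bookkeeping yourself (or via its apply_padding option), which is why it fits best for simple lists such as prompts.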

Understood, thanks @muellerzr !

Hi @muellerzr, is there a way I can do distributed inference using model sharding (FSDP)?

Is there a reason you want to do so instead of using device_map/big model inference? This can help narrow down my recommendation.

@muellerzr My model size is very close to the total GPU memory and from what I understood in this article, I cannot run batches in parallel on all GPUs if I use device_map="auto".

I was wondering if it’s possible to do inference FSDP-style, i.e. the model layers are sharded across all GPUs and exchanged on demand, so that I can process batches in parallel?
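The FSDP-style scheme described in the question can be illustrated with a toy, single-process simulation. Everything here is hypothetical (no torch, no real GPUs, and all_gather is plain list concatenation standing in for a real collective): each simulated GPU permanently keeps only a 1/N shard of every layer's weights, gathers one layer's full weights just before running it, and drops them afterwards, while each rank processes its own batch. It only shows that the sharded forward pass matches an unsharded one; it says nothing about FSDP's real API or performance.

```python
# Toy single-process simulation of FSDP-style sharded inference (hypothetical,
# no torch): weights are stored sharded and all-gathered one layer at a time.

NUM_GPUS = 2

# Full model: each "layer" is a flat weight vector applied elementwise.
full_layers = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5], [2.0, 0.0, 1.0, 3.0]]


def shard(weights, rank, world_size):
    """Contiguous 1/world_size slice of a layer (length assumed divisible)."""
    size = len(weights) // world_size
    return weights[rank * size:(rank + 1) * size]


# Persistent state: what each simulated GPU keeps in memory between layers.
gpu_shards = [[shard(w, r, NUM_GPUS) for w in full_layers] for r in range(NUM_GPUS)]


def all_gather(layer_idx):
    """Stand-in for an all-gather: concatenate every rank's shard of one layer."""
    return [v for r in range(NUM_GPUS) for v in gpu_shards[r][layer_idx]]


def sharded_forward(x):
    for layer_idx in range(len(full_layers)):
        weights = all_gather(layer_idx)  # materialize only this layer
        x = [w * xi for w, xi in zip(weights, x)]
        del weights  # drop the gathered copy before the next layer
    return x


def reference_forward(x):
    for weights in full_layers:
        x = [w * xi for w, xi in zip(weights, x)]
    return x


# Each rank runs its own batch in parallel (data parallelism over sharded weights).
batches = {0: [1.0, 1.0, 1.0, 1.0], 1: [2.0, 0.0, 1.0, 1.0]}
for rank, batch in batches.items():
    print(rank, sharded_forward(batch), reference_forward(batch))

print("weights stored per GPU:", sum(len(s) for s in gpu_shards[0]))
```

The design point is that persistent memory per GPU is roughly the total weights divided by the number of GPUs, plus one full layer at a time during the forward pass.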

Is there an end to end example (in the Example Zoo) that I can refer to?

@sgugger @muellerzr Any thoughts on this?

Hello @varadhbhatnagar, you can use FSDP for distributed inference as long as you aren’t using the generate method, since FSDP is incompatible with generate (mentioned in the docs here: Fully Sharded Data Parallel).

For example, the Fully Sharded Data Parallel docs show this. The example there also computes metrics on an eval set, which should mimic distributed inference.
