Inference on Multi-GPU/Multi-Node

You can try using accelerate.
In this link you can see how to modify code similar to yours to integrate the accelerate library, which can handle the distributed setup for you.
I haven't worked with it directly in a while, so I may be forgetting specific details, like whether (and how) you need to pass it your nodes/GPUs, but I'm sure you can easily find all of that in the docs :slight_smile:

By the way, I just came across this recent post, which might also come in handy for your needs.