How to do multi-GPU inference with ControlNet?

I am finding it difficult to run batch inference with `StableDiffusionControlNetPipeline`. I cannot get `accelerate launch` to work together with a dataloader that feeds batches of images to the pipeline so that generation runs across multiple GPUs. Can anyone help me solve this?

def main(args):
    """Generate images with a ControlNet pipeline, one dataset slice per process.

    Loads the ControlNet and base Stable Diffusion weights in fp16, takes rows
    ``args.start_index`` through ``args.end_index`` (inclusive) of the
    dataset's ``train`` split, runs them through the pipeline in batches and
    writes every generated image as a PNG into ``args.output_dir``.

    Multi-GPU use: when the script is started by a launcher that exports the
    ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` environment variables, each
    process handles every ``WORLD_SIZE``-th row of the requested range on GPU
    ``LOCAL_RANK``. When those variables are absent the function behaves as a
    single process that handles the whole range on GPU 0.

    Args:
        args: Namespace providing ``controlnet_path``, ``base_model_path``,
            ``dataset_name``, ``cache_dir``, ``start_index``, ``end_index``,
            ``batch_size``, ``dataloader_num_workers``, ``seed``,
            ``inf_steps`` and ``output_dir``.
    """
    # Process topology; the defaults describe a plain single-process run.
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_path, torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        args.base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    # Memory optimization: sub-models are moved to the GPU only while they
    # run. The GPU index is passed explicitly so every process uses its own
    # device instead of all of them landing on GPU 0.
    pipe.enable_model_cpu_offload(gpu_id=local_rank)

    dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)

    def collate_fn(examples):
        """Gather per-example fields into parallel lists, without stacking."""
        return {
            "text": [example["text"] for example in examples],
            "conditioning_image": [
                example["conditioning_image"] for example in examples
            ],
            "conditioning_image_id": [
                example["conditioning_image_id"] for example in examples
            ],
        }

    # Strided split of the inclusive [start_index, end_index] range: process
    # `rank` takes rows start+rank, start+rank+world_size, ...  With
    # world_size == 1 this is exactly the full requested range.
    shard_indices = range(
        args.start_index + rank, args.end_index + 1, world_size
    )
    train_dataset = dataset["train"].select(shard_indices)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
    )

    # The generator is created on the same device the pipeline executes on.
    generator = torch.Generator(f"cuda:{local_rank}").manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    for step, batch in enumerate(train_dataloader):
        images = pipe(
            prompt=batch["text"],
            image=batch["conditioning_image"],
            generator=generator,
            num_inference_steps=args.inf_steps,
        ).images
        print(f"Saving images for step {step} in {args.start_index}-{args.end_index}.")
        for img, image_id in zip(images, batch["conditioning_image_id"]):
            # Swap only a trailing TIFF extension for ".png"; a substring
            # replace would also rewrite ".tif" in the middle of a name.
            stem, ext = os.path.splitext(image_id)
            if ext.lower() in (".tif", ".tiff"):
                output_name = stem + ".png"
            else:
                output_name = image_id
            img.save(os.path.join(args.output_dir, output_name), "PNG")