I found it difficult to do batch inference using StableDiffusionControlNetPipeline
. I can not do accelerate launch
and use a dataloader to load batches of images to the pipeline and generate images using multi-GPU. Can anyone help me to solve it?
def main(args):
controlnet = ControlNetModel.from_pretrained(
args.controlnet_path, torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
args.base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
safety_checker=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# memory optimization.
pipe.offload_unet()
pipe.enable_model_cpu_offload()
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)
def collate_fn(examples):
conditioning_image = [example["conditioning_image"] for example in examples]
text = [example["text"] for example in examples]
conditioning_image_id = [
example["conditioning_image_id"] for example in examples
]
return {
"text": text,
"conditioning_image": conditioning_image,
"conditioning_image_id": conditioning_image_id,
}
dataset["train"] = dataset["train"].select(
range(args.start_index, args.end_index + 1)
)
train_dataset = dataset["train"]
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=False,
collate_fn=collate_fn,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers,
pin_memory=True,
)
generator = torch.Generator("cuda").manual_seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
for _step, batch in enumerate(train_dataloader):
prompt_batch = batch["text"]
conditioning_image_batch = batch["conditioning_image"]
conditioning_image_id_batch = batch["conditioning_image_id"]
images = pipe(
prompt=prompt_batch,
image=conditioning_image_batch,
generator=generator,
num_inference_steps=args.inf_steps,
).images
print(f"Saving images for step {_step} in {args.start_index}-{args.end_index}.")
for img, id in zip(images, conditioning_image_id_batch):
output_name = id.replace(".tif", ".png")
img_path = os.path.join(args.output_dir, output_name)
img.save(img_path, "PNG")