Hi team! I am trying to speed up inference with a diffusion model by running it across multiple processes.
Consider this (slightly modified) example from https://huggingface.co/docs/diffusers/en/training/distributed_inference:
```python
import torch
import datasets
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

ds = datasets.load_from_disk("/path/to/my/dataset")
with distributed_state.split_between_processes(ds) as proc_ds:
    proc_ds = proc_ds.map(lambda e: {
        "image": pipeline(e["prompt"]).images[0]
    })
```
How can I now combine the `proc_ds` shards from all processes back into one global dataset that I can continue working with on the main process? Can I just declare a list `results` outside of the `with` block and then, after the `map` inside the `with` block, do `results.append(proc_ds)`? Is that even safe, given that these are separate processes rather than threads? And how would that list ever end up in the main process?
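To make the question concrete, here is a sketch of what I have in mind. The `results` list, the `is_main_process` check, and the final `concatenate_datasets` call are just my guesses at how this might look, not something I know to be correct:

```python
results = []  # this script runs once per process, so every process gets its own copy of this list

with distributed_state.split_between_processes(ds) as proc_ds:
    proc_ds = proc_ds.map(lambda e: {
        "image": pipeline(e["prompt"]).images[0]
    })
    results.append(proc_ds)  # each process appends only its own shard

# On the main process I would then hope to merge the shards, e.g.:
if distributed_state.is_main_process:
    merged = datasets.concatenate_datasets(results)
    # ...but I don't see how the shards computed by the other
    # processes would ever show up in this process's list.
```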
Sorry if this is obvious. Thank you for your help!