Hi, I am new to Stable Dffusion. Currently, I would like to create a custom image inpainting pipeline. I have a problem, because the image beside the mask is being changed and the results on mask area are not getting better. Could you suggest what should I correct?
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
import os
import numpy as np
import torch.nn.functional as F
vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='vae')
unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='unet')
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
scheduler = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(device)
unet.to(device)
text_encoder.to(device)
output_dir = 'inpainting_steps'
os.makedirs(output_dir, exist_ok=True)
def save_image_from_latents(latents, step):
latents = 1 / 0.18215 * latents
with torch.no_grad():
decoded_output = vae.decode(latents)
decoded_image = decoded_output.sample
decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
output_image = (decoded_image * 255).round().astype("uint8")[0]
output_pil_image = Image.fromarray(output_image)
output_pil_image.save(os.path.join(output_dir, f"step_{step:03}.png"))
def encode_prompt(prompt, tokenizer, text_encoder):
text_input = tokenizer(prompt, return_tensors="pt", max_length=tokenizer.model_max_length, padding="max_length", truncation=True)
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
return text_embeddings
def inpaint(image, mask_image, prompt, vae, unet, tokenizer, text_encoder, scheduler, num_inference_steps=50, guidance_scale=7.5):
if image.shape[-1] == 4:
image = image[:, :, :3]
image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device).float()
mask_tensor = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0).to(device).float()
if mask_tensor.max() > 1.0:
mask_tensor = mask_tensor / 255.0
text_embeddings = encode_prompt(prompt, tokenizer, text_encoder)
if guidance_scale > 1.0:
uncond_input = tokenizer("", return_tensors="pt", max_length=tokenizer.model_max_length, padding="max_length", truncation=True)
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
latents = vae.encode(image_tensor).latent_dist.sample()
latents = latents * 0.18215
mask_tensor_resized = F.interpolate(mask_tensor, size=(latents.shape[-2], latents.shape[-1]), mode="nearest")
masked_image_tensor = image_tensor * (1 - mask_tensor)
if masked_image_tensor.shape[1] == 3:
masked_image_tensor = torch.cat([masked_image_tensor, torch.zeros_like(masked_image_tensor[:, :1, :, :])], dim=1)
masked_image_tensor_resized = F.interpolate(masked_image_tensor, size=(latents.shape[-2], latents.shape[-1]), mode="bilinear", align_corners=False)
generator = torch.manual_seed(1234)
noise = torch.randn(latents.shape, generator=generator).to(device)
latents = latents * (1 - mask_tensor_resized) + noise * mask_tensor_resized
scheduler.set_timesteps(num_inference_steps)
with torch.no_grad():
for i, t in enumerate(scheduler.timesteps):
latent_input = scheduler.scale_model_input(latents, t)
latent_input = torch.cat([latent_input, mask_tensor_resized, masked_image_tensor_resized], dim=1)
if guidance_scale > 1.0:
latent_input = torch.cat([latent_input] * 2)
noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
latents = scheduler.step(noise_pred, t, latents).prev_sample
save_image_from_latents(latents, i)
generated_image = vae.decode(latents / 0.18215).sample()
final_image = generated_image * mask_tensor + image_tensor * (1 - mask_tensor)
final_image = (final_image / 2 + 0.5).clamp(0, 1)
final_image = final_image.detach().cpu().permute(0, 2, 3, 1).squeeze().numpy()
final_image = (final_image * 255).astype(np.uint8)
return final_image
image = np.array(Image.open("4.png"))
mask_image = np.array(Image.open("5.png").convert("L"))
prompt = "match the rest of the old fresco"
final_image = inpaint(image, mask_image, prompt, vae, unet, tokenizer, text_encoder, scheduler)