Custom pipeline for image inpainting

Hi, I am new to Stable Diffusion. I would like to build a custom image inpainting pipeline, but I have a problem: the image outside the mask is being changed, and the results inside the masked area are not improving. Could you suggest what I should correct?

import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
import os
import numpy as np
import torch.nn.functional as F

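# Load the inpainting VAE and UNet, plus the CLIP tokenizer and text encoder used for conditioning.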
vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='vae')
unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='unet')
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(device)
unet.to(device)
text_encoder.to(device)

output_dir = 'inpainting_steps'
os.makedirs(output_dir, exist_ok=True)

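# Decode the current latents to pixel space and save a PNG so each denoising step can be inspected.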
def save_image_from_latents(latents, step):
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        decoded_output = vae.decode(latents)
        decoded_image = decoded_output.sample
    decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
    output_image = (decoded_image * 255).round().astype("uint8")[0]
    output_pil_image = Image.fromarray(output_image)
    output_pil_image.save(os.path.join(output_dir, f"step_{step:03}.png"))

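# Tokenize the prompt and run it through the text encoder to get the conditioning embeddings for the UNet.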
def encode_prompt(prompt, tokenizer, text_encoder):
    text_input = tokenizer(prompt, return_tensors="pt", max_length=tokenizer.model_max_length, padding="max_length", truncation=True)
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    return text_embeddings

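# Main inpainting routine: encode the image to latents, build the mask and masked-image
# conditioning, denoise with classifier-free guidance, then blend the decoded result
# with the original image outside the mask.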
def inpaint(image, mask_image, prompt, vae, unet, tokenizer, text_encoder, scheduler, num_inference_steps=50, guidance_scale=7.5):
    if image.shape[-1] == 4:
        image = image[:, :, :3]
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device).float()
    mask_tensor = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0).to(device).float()
    if mask_tensor.max() > 1.0:
        mask_tensor = mask_tensor / 255.0
    text_embeddings = encode_prompt(prompt, tokenizer, text_encoder)
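    # For classifier-free guidance, also encode an empty prompt and stack it with the conditional embeddings.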
    if guidance_scale > 1.0:
        uncond_input = tokenizer("", return_tensors="pt", max_length=tokenizer.model_max_length, padding="max_length", truncation=True)
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
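    # Encode the full input image into the latent space and apply the SD scaling factor.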
    latents = vae.encode(image_tensor).latent_dist.sample()
    latents = latents * 0.18215
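    # Downsample the mask to the latent resolution and prepare the masked-image conditioning channels.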
    mask_tensor_resized = F.interpolate(mask_tensor, size=(latents.shape[-2], latents.shape[-1]), mode="nearest")
    masked_image_tensor = image_tensor * (1 - mask_tensor)
    if masked_image_tensor.shape[1] == 3:
        masked_image_tensor = torch.cat([masked_image_tensor, torch.zeros_like(masked_image_tensor[:, :1, :, :])], dim=1)
    masked_image_tensor_resized = F.interpolate(masked_image_tensor, size=(latents.shape[-2], latents.shape[-1]), mode="bilinear", align_corners=False)
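    # Start from the original latents, with the masked region replaced by random noise.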
    generator = torch.manual_seed(1234)
    noise = torch.randn(latents.shape, generator=generator).to(device)
    latents = latents * (1 - mask_tensor_resized) + noise * mask_tensor_resized
    scheduler.set_timesteps(num_inference_steps)
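    # Denoising loop: the inpainting UNet takes the latents, the mask and the masked image concatenated along the channel dimension.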
    with torch.no_grad():
        for i, t in enumerate(scheduler.timesteps):
            latent_input = scheduler.scale_model_input(latents, t)
            latent_input = torch.cat([latent_input, mask_tensor_resized, masked_image_tensor_resized], dim=1)
            if guidance_scale > 1.0:
                latent_input = torch.cat([latent_input] * 2)
                noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            save_image_from_latents(latents, i)
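    # Decode the final latents and keep the original pixels outside the mask.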
    with torch.no_grad():
        generated_image = vae.decode(latents / 0.18215).sample
    final_image = generated_image * mask_tensor + image_tensor * (1 - mask_tensor)
    final_image = (final_image / 2 + 0.5).clamp(0, 1)
    final_image = final_image.detach().cpu().permute(0, 2, 3, 1).squeeze().numpy()
    final_image = (final_image * 255).astype(np.uint8)
    return final_image

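# Example usage: the mask is expected to be white (255) where the image should be inpainted.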
image = np.array(Image.open("4.png"))
mask_image = np.array(Image.open("5.png").convert("L"))
prompt = "match the rest of the old fresco"
final_image = inpaint(image, mask_image, prompt, vae, unet, tokenizer, text_encoder, scheduler)


runwayml/stable-diffusion-inpainting
Not the essential part of this question, but I think this repo is gone. I don't know the details of what exactly happened, but the entire repository has been removed from the Hub.

vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='vae')
unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-inpainting', subfolder='unet')

https://huggingface.co/runwayml/stable-diffusion-inpainting
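
If downloading from that repo fails for you too, one workaround is to point the same loading code at another diffusers-format inpainting checkpoint. A minimal sketch, assuming stabilityai/stable-diffusion-2-inpainting is still available with the usual vae/unet/text_encoder/tokenizer subfolders (its text encoder is not the same as openai/clip-vit-large-patch14, so load the tokenizer and text encoder from the same repo):

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel

# Assumed alternative checkpoint; swap in whichever inpainting repo you can access.
repo_id = 'stabilityai/stable-diffusion-2-inpainting'

# Load every component from the same repo so the text-embedding size matches the UNet.
vae = AutoencoderKL.from_pretrained(repo_id, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder='unet')
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(repo_id, subfolder='text_encoder')

As far as I know that UNet also expects the 9-channel (latents + mask + masked image) input, so the rest of your loop should not need to change.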