Optimizing my SDXL pipeline

I’m trying to speed up inference in my SDXL pipeline. Within my pipeline I’m using controlnet and IP adaptation. I’ve read through huggingface’s guides to optimizing and speeding up inference and I keep getting error after error. It feels like I’m making my way through a minefield where only one correct path exists, and I’m doing it without a map.

My approach is to throw spaghetti at the wall to see what sticks, slowly adding more and more to my pipeline while I figure out what’s wrong with the parts that aren’t working. Whenever I add quantization or torch.compile to my pipeline I get the error

“RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same”

Even when I have it set up exactly the same as it’s set up in the optimization guidelines it still gives me the error. This is the guide I’m following: Accelerate inference of text-to-image diffusion models

I’ll post my code below. Can someone please help me understand what I’m doing wrong?

import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from torchao.quantization import swap_conv2d_1x1_to_linear, apply_dynamic_quant
import tomesd
from DeepCache import DeepCacheSDHelper

DTYPE = torch.bfloat16

def apply_pipeline_optimizations(pipe):
    pipe.torch_dtype = DTYPE
    # pipe.to(device="cuda", dtype=DTYPE, memory_format=torch.channels_last)
    pipe.to("cuda")

    # Inferrence Speed Optimizations
    pipe.enable_vae_slicing()
    pipe.enable_vae_tiling()
    pipe.enable_attention_slicing()

    #Apply token merging
    tomesd.apply_patch(pipe, ratio=0.5)

    # Apply DeepCache optimizations
    helper = DeepCacheSDHelper(pipe=pipe)
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    helper.enable()

    # Memory Optimizations
    pipe.enable_model_cpu_offload() # -- This optimization slightly reduces memory consumption, but is optimized for speed.
    # pipe.enable_sequential_cpu_offload() # -- This optimization reduces memory consumption, but also reduces speed.
    pipe.enable_xformers_memory_efficient_attention() # -- useless for torch > 2.0, but if using torch < 2.0, this is an essential optimization.

    # Torch Optimizations
    print("[torch-pipeline] Applying base torch optimizations...")
    # pipe.unet.set_attn_processor(AttnProcessor2_0())
    pipe.fuse_qkv_projections()
    # pipe.unet.to(device="cuda", dtype=DTYPE, memory_format=torch.channels_last)
    # # pipe.vae.to(device="cuda", dtype=DTYPE, memory_format=torch.channels_last)
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    # print("[torch-pipeline] Quantizing torch models...")
    # swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
    # swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)
    # apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
    # apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

    print("[torch-pipeline] Compiling torch models...")
    # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # max-autotune or reduce-overhead
    # pipe.vae.decode = torch.compile(pipe.vae.decode, mode="reduce-overhead", fullgraph=True)
    pipe.upcast_vae()

# The following functions are from: https://huggingface.co/docs/diffusers/tutorials/fast_diffusion
def dynamic_quant_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )


def conv_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels]
    )