Fine-tuning ControlNet-XS with SDXL

I’m trying to fine-tune ControlNet-XS by adapting the ControlNet SDXL training script in diffusers. The error occurs in the following piece of code, around line 1205 of the script:

controlnet = ControlNetXSAdapter.from_unet(unet, block_out_channels=[320, 640, 1280])

controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=batch["prompt_ids"],
    added_cond_kwargs=batch["unet_added_conditions"],
    controlnet_cond=controlnet_image,
    return_dict=False,
)

But I’m getting the following ValueError:

ValueError: A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel.

I tried to fix it by wrapping the adapter in a UNetControlNetXSModel before the call above, but that didn’t work:

controlnet = UNetControlNetXSModel(unet, controlnet)

Does anyone have any ideas on how to solve this?