Loading txt2img LoRA after training leads to noisy images

I've got a slightly customized txt2img LoRA training script based on the diffusers example. As far as I can tell, the changes are irrelevant: a multi-aspect-ratio dataset/sampler, plus training with v_prediction and zero terminal SNR (ztsnr).

The problem I'm experiencing is that when training completes, the final evaluation step, which reloads the saved LoRA weights into a fresh pipeline, produces noisy images unless I manually re-add the LoRA configs to the pipeline.

Here is a training image, followed by the final "test" image produced when I don't re-add the LoRA adapters after creating the pipeline object.

Here's the impacted code:

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        set_eval()

        unet = unwrap_model(unet)
        unet = unet.to(torch.float32)

        unet_lora_state_dict = convert_state_dict_to_diffusers(
            get_peft_model_state_dict(unet)
        )

        if args.train_text_encoder:
            text_encoder = unwrap_model(text_encoder)
            text_encoder_state_dict = convert_state_dict_to_diffusers(
                get_peft_model_state_dict(text_encoder)
            )
        else:
            text_encoder_state_dict = None

        StableDiffusionPipeline.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            text_encoder_lora_layers=text_encoder_state_dict,
        )

        # Final inference
        # Load previous pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            scheduler=DDPMScheduler.from_pretrained(
                args.pretrained_model_name_or_path,
                subfolder="scheduler",
                prediction_type="v_prediction",
                rescale_betas_zero_snr=True,
                timestep_spacing="trailing",
            ),
            torch_dtype=weight_dtype,
        )

        # Without re-adding the two LoraConfigs here, the outputs are corrupted.
        unet_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.rank,
            init_lora_weights="pissa",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        pipeline.unet.add_adapter(unet_lora_config)
        if args.train_text_encoder:
            text_lora_config = LoraConfig(
                r=args.rank,
                lora_alpha=args.rank,
                init_lora_weights="pissa",
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
            )
            pipeline.text_encoder.add_adapter(text_lora_config)

        # Load the Lora weights
        pipeline.load_lora_weights(
            args.output_dir, weight_name="pytorch_lora_weights.safetensors"
        )

        # run inference
        images = log_validation(
            pipeline, args, accelerator, epoch, is_final_validation=True
        )
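
For reference, the `r` and `lora_alpha` values in the `LoraConfig` objects above feed the standard LoRA update, where a frozen weight W is augmented as W x + (lora_alpha / r) · B (A x). That is PEFT's default scaling when rsLoRA is not enabled. With `lora_alpha=args.rank`, the scale is exactly 1. The following is a minimal pure-Python sketch of that computation, not PEFT's implementation. The matrices and the helper names `lora_forward` and `merge_lora` are hypothetical and chosen only to keep the arithmetic easy to follow.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, r: int, lora_alpha: float) -> Vector:
    """Un-merged LoRA linear layer (no bias): W x + (lora_alpha / r) * B (A x)."""
    scale = lora_alpha / r
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scale * d for b, d in zip(base, delta)]


def merge_lora(W: Matrix, A: Matrix, B: Matrix, r: int, lora_alpha: float) -> Matrix:
    """Fold the adapter into the base weight: W + (lora_alpha / r) * B @ A."""
    scale = lora_alpha / r
    return [
        [W[i][j] + scale * sum(B[i][k] * A[k][j] for k in range(r)) for j in range(len(W[0]))]
        for i in range(len(W))
    ]


# Hypothetical tiny layer: 2 outputs, 3 inputs, rank r = 1.
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 1.0, 1.0]]   # shape (r, in_features)
B = [[0.5], [0.0]]      # shape (out_features, r)
x = [1.0, 2.0, 3.0]

# lora_alpha == r, as in the LoraConfig above, so the scale is 1.0.
y = lora_forward(W, A, B, x, r=1, lora_alpha=1)
y_merged = matvec(merge_lora(W, A, B, r=1, lora_alpha=1), x)
print(y, y_merged)  # [4.0, 2.0] [4.0, 2.0]
```

The merged and un-merged forms give the same output. Merging fuses the adapter into the base weight. Keeping A and B separate is what lets an adapter be loaded, swapped, or removed later.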

The most interesting part is that when I re-add the adapters as above, PEFT logs the following warning:

    08/22/2024 09:14:33 - INFO - peft.tuners.tuners_utils - Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
    08/22/2024 09:14:40 - INFO - peft.tuners.tuners_utils - Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
    08/22/2024 09:14:47 - INFO - __main__ - Running validation...

But if I skip that step, and so never see the warning, my outputs are garbage.

Does anyone have any insights about why this might happen? Is this a bug I should open an issue for, or am I missing something obvious?

Related package versions:

    accelerate==0.24.1
    diffusers==0.30.0
    peft==0.12.0
    tokenizers==0.19.1
    torch==2.3.0
    transformers==4.41.2

This was caused by two things:

  • the LoRA examples in diffusers are not all up to date. I copied from an SD example, but it seems only the SDXL examples get appropriate love.
  • there was still a slight discrepancy in the outputs after fixing the saving/loading of the LoRA weights. This was caused by not properly upcasting the LoRA weights when saving them at the end of training.
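
The second point is easier to see with numbers. The following is a minimal, library-free sketch, not the poster's actual fix. It round-trips some hypothetical LoRA parameter values through half precision (fp16) and single precision (fp32) using Python's `struct` module, whose `'e'` and `'f'` formats pack IEEE 754 half- and single-precision floats. The fp16 round trip loses noticeably more precision. That is the kind of small output discrepancy you would expect if LoRA weights are stored without being upcast to fp32 first.

```python
import struct
from typing import List


def round_trip(values: List[float], fmt: str) -> List[float]:
    """Pack floats with a struct format code ('e' = fp16, 'f' = fp32) and unpack them again."""
    layout = f"<{len(values)}{fmt}"
    return list(struct.unpack(layout, struct.pack(layout, *values)))


def max_abs_error(original: List[float], restored: List[float]) -> float:
    """Largest absolute difference between corresponding values."""
    return max(abs(a - b) for a, b in zip(original, restored))


# Hypothetical LoRA parameter values (illustrative only, not taken from a real checkpoint).
lora_values = [0.0123456789, -0.00098765, 0.3141592653, 1e-5]

err_fp16 = max_abs_error(lora_values, round_trip(lora_values, "e"))
err_fp32 = max_abs_error(lora_values, round_trip(lora_values, "f"))
print(f"fp16 round-trip max error: {err_fp16:.3e}")
print(f"fp32 round-trip max error: {err_fp32:.3e}")
```

Python floats are 64-bit, so both errors here are measured against double precision. What matters is the gap between the two formats, not the exact values.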
