Saving and loading modified Unet

Hello, I am loading a UNet model and adding an extra class-embedding layer to it like this:
# Load the pretrained UNet; ``args`` and ``weight_dtype`` come from the
# surrounding training script.
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# Attach an extra class embedding: TimestepEmbedding(16, 1280) wrapped in
# ToWeightsDType.
# NOTE(review): assigning the submodule directly does not appear to update
# ``unet.config``, so a save_pretrained / from_pretrained round trip rebuilds the
# UNet without this layer. That matches the "Missing key(s)" error quoted in
# this post. The wrapper also stores its weights under
# ``class_embedding.module.*`` rather than ``class_embedding.*``.
camera_embedding = ToWeightsDType(TimestepEmbedding(16, 1280), weight_dtype)
unet.class_embedding = camera_embedding

These are my save and load hooks:
def save_model_hook(models, weights, output_dir):
    """Accelerate save-state hook: write the UNet / CLIP text encoder via ``save_pretrained``.

    Args:
        models: models tracked by the accelerator.
        weights: state dicts handed to the hook alongside ``models``; presumably
            index-aligned with ``models`` (accelerate's convention -- confirm).
        output_dir: checkpoint directory; each handled model is written to a
            subfolder (``unet`` or ``text_encoder``).

    Models are classified by substring match on ``str(type(model))``. Every
    model saved here has its entry removed from ``weights`` so it is not saved a
    second time. Unrecognised models are skipped and their weights left in place.
    """
    saved_indices = []
    for index, model in enumerate(models):
        model_type = str(type(model))

        print(f'model type: {model_type}')
        # Check for substring matches in the model type
        if 'UNet' in model_type:
            sub_dir = "unet"
        elif 'CLIPTextModel' in model_type:
            sub_dir = "text_encoder"
        else:
            # Skip instead of falling through with ``sub_dir`` unbound.
            print(f'Unknown model type: {model_type}. Skipping save.')
            continue
        model.save_pretrained(os.path.join(output_dir, sub_dir))
        saved_indices.append(index)

    # make sure to pop weight so that corresponding model is not saved again.
    # Pop from the highest index down so the remaining indices stay valid.
    for index in reversed(saved_indices):
        weights.pop(index)

    def load_model_hook(models, input_dir):
        """Accelerate load-state hook: restore the UNet / CLIP text encoder from *input_dir*.

        Args:
            models: models tracked by the accelerator; entries handled here are
                removed from the list so they are not loaded again.
            input_dir: checkpoint directory written by the matching save hook
                (``unet`` / ``text_encoder`` subfolders).

        Models are classified by substring match on ``str(type(model))``. For a
        handled model, a fresh copy is loaded from disk, its config is copied
        over, and then its weights are copied into the live model with
        ``load_state_dict``. Unrecognised models are put back into ``models`` in
        their original order, so the default loader can handle them.
        """
        unhandled = []
        while len(models) > 0:
            # pop models so that they are not loaded again
            model = models.pop()
            model_type = str(type(model))

            if 'CLIPTextModel' in model_type:
                # load transformers style into model
                load_model = text_encoder_cls.from_pretrained(
                    input_dir, subfolder="text_encoder")
                model.config = load_model.config
            elif 'UNet' in model_type:
                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(
                    input_dir, subfolder="unet")
                # Keep the live model's config in sync, mirroring the CLIP branch.
                model.register_to_config(**load_model.config)
            else:
                print(f'Unknown model type: {model_type}. Skipping load.')
                unhandled.append(model)
                continue

            # Copy the loaded weights into the live model; without this the
            # checkpoint is read but never applied.
            # NOTE(review): ``load_model`` is rebuilt from the *saved config*. If
            # the live UNet carries a class_embedding that the config does not
            # describe (e.g. one assigned directly as an attribute), these keys
            # are absent and this raises the "Missing key(s)" error.
            model.load_state_dict(load_model.state_dict())
            del load_model

        # Give unhandled models back in their original order.
        models.extend(reversed(unhandled))

But when I try to load the saved model, I get this error:

RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
Missing key(s) in state_dict: "class_embedding.module.linear_1.weight", "class_embedding.module.linear_1.bias", "class_embedding.module.linear_2.weight", "class_embedding.module.linear_2.bias".

I think the model config is not updated to describe the new class embedding when the model is saved.
Please let me know how to resolve this issue.

Thank you.

Solved it by passing the class-embedding settings to `from_pretrained` when loading the UNet:
# Rebuild the UNet with the class embedding declared through config kwargs, so
# from_pretrained creates ``class_embedding`` before loading the weights.
# low_cpu_mem_usage=False / device_map=None let diffusers build a model whose
# keys differ from the checkpoint. Missing weights are then presumably
# initialized rather than raising -- confirm.
# NOTE(review): ``model`` is presumably the checkpoint directory here. A
# ToWeightsDType-wrapped embedding saves its weights as
# ``class_embedding.module.*``, while class_embed_type="projection" builds
# ``class_embedding.linear_1`` / ``linear_2``. Check the load warnings to make
# sure the projection weights were actually restored, not left randomly
# initialized.
unet = UNet2DConditionModel.from_pretrained(
    model,
    subfolder="unet",
    class_embed_type="projection",
    projection_class_embeddings_input_dim=16,
    revision=None,
    low_cpu_mem_usage=False,
    device_map=None,
)