Hello, I am loading a UNet model and adding an extra class-embedding layer (a camera embedding) like this:
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Attach a custom camera embedding as the UNet's class embedding.
# TimestepEmbedding(16, 1280): presumably a 16-dim camera vector projected to
# 1280 features, which must match the UNet's time-embedding width — confirm.
# NOTE: nothing in unet.config describes this module, so a UNet re-created with
# UNet2DConditionModel.from_pretrained() will not have it. That is why reloading
# reports missing class_embedding.module.* keys; restore its weights from the
# raw checkpoint file instead of from a freshly built model.
camera_embedding = ToWeightsDType(TimestepEmbedding(16, 1280), weight_dtype)
unet.class_embedding = camera_embedding
These are my save and load hooks:
def save_model_hook(models, weights, output_dir):
    """Save each recognised model with ``save_pretrained`` under *output_dir*.

    UNet models go to ``<output_dir>/unet`` and CLIP text encoders to
    ``<output_dir>/text_encoder``. For every model saved here, the matching
    entry is removed from ``weights`` so it is not saved a second time.
    Unrecognised models keep their entry in ``weights``.

    Args:
        models: Models being checkpointed.
        weights: State dicts assumed to be aligned index-for-index with
            ``models`` (Accelerate's save-state hook convention — confirm);
            mutated in place.
        output_dir: Checkpoint directory.
    """
    saved_indices = []
    for index, model in enumerate(models):
        model_type = str(type(model))
        print(f'model type: {model_type}')
        # Check for substring matches in the model type
        if 'UNet' in model_type:
            sub_dir = "unet"
        elif 'CLIPTextModel' in model_type:
            sub_dir = "text_encoder"
        else:
            print(f'Unknown model type: {model_type}. Skipping save.')
            continue
        model.save_pretrained(os.path.join(output_dir, sub_dir))
        saved_indices.append(index)

    # Remove exactly the saved models' weights, highest index first so the
    # remaining indices stay valid. A bare ``weights.pop()`` always removes the
    # LAST entry, which belongs to a different model whenever an unknown model
    # was skipped above.
    for index in reversed(saved_indices):
        weights.pop(index)
def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()
model_type = str(type(model))
if 'CLIPTextModel' in model_type:
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(
input_dir, subfolder="text_encoder")
model.config = load_model.config
elif 'UNet' in model_type:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(
input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
else:
print(f'Unknown model type: {model_type}. Skipping load.')
continue
model.load_state_dict(load_model.state_dict())
del load_model
But when I try to load the saved model, I get this error:
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
Missing key(s) in state_dict: "class_embedding.module.linear_1.weight", "class_embedding.module.linear_1.bias", "class_embedding.module.linear_2.weight", "class_embedding.module.linear_2.bias".
I think the saved config is not updated to include the added class embedding when the model is saved.
Please let me know how to resolve this issue.
Thank you.