Custom safety checker class

Hi! I’m currently working on an adaptation of the safety checker pipeline. I want something similar to the one that already works, but with customizable concepts, and outputting only the cosine distance to each concept. Simple, right? But I’m failing to understand how to load the weights: where it currently fails is that the vision_model does not seem to load the weights of the model I’m passing via parameters.

This is how I’m trying to initialize the class:

concepts = ['sexual', 'nude', 'sex', '18+', 'naked', 'nsfw', 'porn', 'dick', 'vagina', 'naked child', 'explicit content', 'uncensored', 'fuck', 'nipples', 'visible nipples', 'naked breasts', 'areola']
safety_checker_custom = StableDiffusionSafetyChecker.from_pretrained(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    concepts=concepts,
    subfolder="safety_checker",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
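
For context, this is roughly how I expect to use it once the loading works. The CLIPFeatureExtractor preprocessing and the file name are just my assumptions, copied from how the standard pipeline feeds its safety checker, and the output is the per-concept cosine distance dict I described above:

from PIL import Image
from transformers import CLIPFeatureExtractor
import torch

# Assumed preprocessing: the same CLIP feature extractor the standard pipeline uses
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("generated.png")  # placeholder: any generated image
clip_input = feature_extractor(images=image, return_tensors="pt").pixel_values
clip_input = clip_input.to("cuda", dtype=torch.float16)

# Expected output: one {concept: cosine distance} dict per image in the batch
distances = safety_checker_custom(clip_input=clip_input, images=[image])
print(distances[0])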

And this is the code of the class. I assume that the weights of the model passed as a parameter should be loaded in the from_pretrained class method, but I don’t know how to do it, and I couldn’t find any Hugging Face documentation on how to implement custom classes.

import torch
from torch import nn
from transformers import AutoTokenizer, CLIPConfig, CLIPModel, CLIPVisionModel, PreTrainedModel
# cosine_distance is the helper used by diffusers' stock safety checker
from diffusers.pipelines.stable_diffusion.safety_checker import cosine_distance


class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig, concepts, revision=None, torch_dtype=None):
        super().__init__(config)

        # Same vision backbone and projection as the stock safety checker
        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(
            config.vision_config.hidden_size, config.projection_dim, bias=False
        )

        # Embed the custom concepts once with the CLIP text encoder
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        inputs = tokenizer(concepts, padding=True, return_tensors="pt")
        text_features = model.get_text_features(**inputs)
        self.concept_embeds = nn.Parameter(text_features, requires_grad=False)
        self.concepts = concepts

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, concepts, **kwargs):
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config, concepts)

        # Load the model weights (this is the part I can't figure out)
        # model.load_state_dict(pretrained_model_name_or_path)

        return model

    @torch.no_grad()
    def forward(self, clip_input, images):
        # `images` is unused here; it is kept to match the stock safety checker's signature
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()

        # Build one {concept: cosine distance} dict per image in the batch
        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            results = {}
            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                results[self.concepts[concept_idx]] = concept_cos
            result.append(results)

        return result
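
In case it clarifies what I'm stuck on, my current guess (completely unverified) is that I shouldn't load the state dict by hand at all, but instead delete my from_pretrained override and rely on the inherited PreTrainedModel.from_pretrained, which loads the checkpoint weights itself and passes extra positional arguments on to __init__. Roughly:

# Unverified sketch: remove the custom from_pretrained override above and let the
# inherited PreTrainedModel.from_pretrained load the checkpoint weights.
safety_checker_custom = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    concepts,  # extra positional args are forwarded to __init__(config, concepts)
    subfolder="safety_checker",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Caveat I'm unsure about: if the checkpoint stores a concept_embeds tensor with a matching
# shape, it would overwrite my custom embeddings, so I might have to re-assign them after loading.

Is this the intended pattern for custom classes, or is there a documented way to load the weights inside from_pretrained itself?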