Building Custom AutoModelForTask

Hi, I want to build something like a custom AutoModelForxx class. For example, I know that you can create a custom class for a specific model like this:

class BertForEmbedding(BertPreTrainedModel):
    config_class = BertConfig

    def __init__(self, config, embed_pooling: str = "cls"):
        super().__init__(config)
        self.bert = BertModel(config=config)
        self.embed_pooling = embed_pooling
        self.init_weights()

    def forward(
        self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None
    ):
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self._pooling(output, attention_mask)
        
    def _pooling(self, output, attention_mask):
        if self.embed_pooling == "mean":
            return self._mean_pooling(output, attention_mask)
        elif self.embed_pooling == "cls":
            return self._cls_pooling(output)

    def _mean_pooling(self, output, attention_mask):
        token_embeddings = output.last_hidden_state
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def _cls_pooling(self, output):
        return output.last_hidden_state[:, 0, :]

I wanted to do the same thing, but for a general model. Here’s my attempt:

class AutoModelForEmbedding(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config, pooler: Optional[Pooler] = None):
        super().__init__(config)
        if pooler is None:
            self.pooler = CLSPooler()

        self.model = AutoModel.from_config(config.model_name_or_path)
        self.init_weight()

    def _init_weights(self, module):
        self.model._init_weights(module)

    def forward(self, **kwargs):
        output = self.model(**kwargs)
        return self.pooler(output, **kwargs)

The problem is that the weights are not properly initialized with the pretrained ones. Does anyone know how to properly do it?