I'm having trouble saving and loading my model: the state dict keys don't match

I saved my fine-tuned model, and when I load it back I run into a problem with the saved state dictionary: every other part of the network loads fine, but one layer triggers a warning because its saved keys don't match the keys in the model.

Here is the warning message:
Some weights of the model checkpoint at best_models/deberta-v2-chinese-gp-gccn were not used when initializing CAILNet: ['cln.bias', 'cln.bias_dense.weight', 'cln.weight', 'cln.weight_dense.weight']
Some weights of CAILNet were not initialized from the model checkpoint at best_models/deberta-v2-chinese-gp-gccn and are newly initialized: ['cln.beta', 'cln.beta_dense.weight', 'cln.gamma', 'cln.gamma_dense.weight']
After debugging, I found that the keys in the model definition are ['cln.beta', 'cln.beta_dense.weight', 'cln.gamma', 'cln.gamma_dense.weight'], while the saved keys reported for the checkpoint are ['cln.bias', 'cln.bias_dense.weight', 'cln.weight', 'cln.weight_dense.weight'].
Because of this mismatch, the weights for that layer are not loaded. I'm sure I never changed the code or structure of this layer, and evaluation right after training works fine; the problem only appears when I reload the model.
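
For reference, this is roughly how I compared the two sets of keys while debugging (assuming model is a freshly constructed CAILNet instance and the checkpoint file is the standard pytorch_model.bin; adjust the path if yours differs):

import torch

# Compare what is stored in the checkpoint with what the model expects,
# restricted to the cln layer that produces the warning.
saved = set(torch.load(
    "best_models/deberta-v2-chinese-gp-gccn/pytorch_model.bin", map_location="cpu"
).keys())
expected = set(model.state_dict().keys())  # model: a freshly built CAILNet instance

print("only in checkpoint:", sorted(k for k in saved - expected if k.startswith("cln.")))
print("only in model:", sorted(k for k in expected - saved if k.startswith("cln.")))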

import torch
import torch.nn as nn


class ConLayerNorm(nn.Module):
    def __init__(self, input_dim, cond_dim=0, center=True, scale=True, epsilon=None, conditional=False,
                 hidden_units=None, hidden_activation='linear', hidden_initializer='xavier'):
        """
        input_dim: inputs.shape[-1]
        cond_dim: cond.shape[-1]
        """
        super().__init__()
        self.center = center
        self.scale = scale
        self.conditional = conditional
        self.hidden_units = hidden_units
        self.hidden_initializer = hidden_initializer
        self.epsilon = epsilon or 1e-12
        self.input_dim = input_dim
        self.cond_dim = cond_dim

        if self.center:
            self.beta = nn.Parameter(torch.zeros(input_dim))
        if self.scale:
            self.gamma = nn.Parameter(torch.ones(input_dim))

        if self.conditional:
            if self.hidden_units is not None:
                self.hidden_dense = nn.Linear(in_features=self.cond_dim, out_features=self.hidden_units, bias=False)
            if self.center:
                self.beta_dense = nn.Linear(in_features=self.cond_dim, out_features=input_dim, bias=False)
            if self.scale:
                self.gamma_dense = nn.Linear(in_features=self.cond_dim, out_features=input_dim, bias=False)

        self.initialize_weights()

    def initialize_weights(self):

        if self.conditional:
            if self.hidden_units is not None:
                if self.hidden_initializer == 'normal':
                    torch.nn.init.normal_(self.hidden_dense.weight)
                elif self.hidden_initializer == 'xavier':  # glorot_uniform
                    torch.nn.init.xavier_uniform_(self.hidden_dense.weight)

            if self.center:
                torch.nn.init.constant_(self.beta_dense.weight, 0)
            if self.scale:
                torch.nn.init.constant_(self.gamma_dense.weight, 0)

    def forward(self, inputs, cond=None):
        if self.conditional:
            if self.hidden_units is not None:
                cond = self.hidden_dense(cond)

            for _ in range(len(inputs.shape) - len(cond.shape)):
                cond = cond.unsqueeze(1)  # cond = K.expand_dims(cond, 1)

            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            if self.center:
                beta = self.beta
            if self.scale:
                gamma = self.gamma

        outputs = inputs
        if self.center:
            mean = torch.mean(outputs, dim=-1).unsqueeze(-1)
            outputs = outputs - mean
        if self.scale:
            variance = torch.mean(outputs ** 2, dim=-1).unsqueeze(-1)
            std = (variance + self.epsilon) ** 0.5
            outputs = outputs / std
            outputs = outputs * gamma
        if self.center:
            outputs = outputs + beta

        return outputs
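
For context, the layer is used roughly like this (the shapes below are illustrative, not the ones from my actual model):

import torch

# Illustrative only: batch of 2, sequence length 8, hidden size 768,
# with a 128-dim condition vector driving the conditional layer norm.
layer = ConLayerNorm(input_dim=768, cond_dim=128, conditional=True)
x = torch.randn(2, 8, 768)
cond = torch.randn(2, 128)
out = layer(x, cond)  # cond is broadcast over the sequence dimension
print(out.shape)      # torch.Size([2, 8, 768])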

The problem seems to be in the transformers weight-loading code: when the state dict is loaded, checkpoint keys containing beta are renamed to bias and keys containing gamma are renamed to weight:

        def _fix_key(key):
            if "beta" in key:
                return key.replace("beta", "bias")
            if "gamma" in key:
                return key.replace("gamma", "weight")
            return key

        original_loaded_keys = loaded_keys
        loaded_keys = [_fix_key(key) for key in loaded_keys]

But is there an elegant way to deal with this?
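
So far the only workaround I've found is to bypass from_pretrained for this checkpoint and call load_state_dict directly, so the beta-to-bias / gamma-to-weight renaming never runs. This is only a sketch, and it assumes the file on disk still stores the original cln.beta / cln.gamma names and is the standard pytorch_model.bin:

import torch

# Rough workaround: load the raw state dict and feed it to load_state_dict
# directly instead of going through from_pretrained's key handling.
# Assumes the model instance has already been constructed as `model`.
state_dict = torch.load(
    "best_models/deberta-v2-chinese-gp-gccn/pytorch_model.bin", map_location="cpu"
)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)

Another option might be to rename the parameters in ConLayerNorm from gamma/beta to weight/bias so they survive the renaming, but that would mean migrating my existing checkpoints.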