Custom BERT embedding causes "RuntimeError: Trying to backward through the graph a second time"

Hello, I'm working on a BertModel in which I need to fine-tune only part of the word embedding layer while keeping the weights of the rest of the model frozen.
I wrote a custom embedding layer that inherits from nn.Embedding. Because part of the word embedding weights must stay frozen, I build a masked copy, new_weight, from nn.Embedding.weight in __init__ and pass it to F.embedding in forward() instead of the original weight,
as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional

class P_Embedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2, scale_grad_by_freq: bool = False, sparse: bool = False, _weight: Optional[Tensor] = None, device=None, dtype=None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device, dtype=dtype)
        # frozen copy of the weights: no gradient flows through it
        weight_sg = self.weight.detach()
        # rows 0-9 are trainable (mask = 1), all other rows are frozen (mask = 0)
        mask = torch.zeros(self.num_embeddings, 1, device=self.weight.device)
        mask[0:10] = 1
        # masked weight, built once here in __init__
        self.new_weight = self.weight * mask + weight_sg * (1 - mask)
        # self.weight.set_(self.new_weight)
        # self.new_weight.retain_grad()

    def forward(self, input: Tensor) -> Tensor:
        return F.embedding(
            input, self.new_weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)
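One way to get the partial fine-tuning described above is to rebuild the masked weight on every forward() call, so that each training step gets its own autograd graph. Below is a minimal sketch in plain PyTorch; MaskedEmbedding, trainable_rows and the grad_mask buffer are hypothetical names, and "the first 10 rows are trainable" is taken from mask[0:10] in the code above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedEmbedding(nn.Embedding):
    """Hypothetical variant of P_Embedding: only the first `trainable_rows` rows get gradients."""

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 trainable_rows: int = 10, **kwargs) -> None:
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        mask = torch.zeros(num_embeddings, 1, dtype=self.weight.dtype, device=self.weight.device)
        mask[:trainable_rows] = 1.0
        # A buffer moves with model.to(device); persistent=False keeps it out of
        # the state_dict, so loading a pretrained checkpoint does not expect it.
        self.register_buffer("grad_mask", mask, persistent=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Rebuilt on every call, so each training step gets a fresh autograd graph.
        # Values equal self.weight; the detached term blocks gradients to masked-out rows.
        weight = self.weight * self.grad_mask + self.weight.detach() * (1 - self.grad_mask)
        return F.embedding(input, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq, self.sparse)


# Usage: two optimizer steps (the second step is where a graph built once in __init__ fails).
torch.manual_seed(0)
emb = MaskedEmbedding(20, 4, trainable_rows=10)
optimizer = torch.optim.SGD(emb.parameters(), lr=0.1)
before = emb.weight.detach().clone()
ids = torch.arange(20)

for _ in range(2):
    optimizer.zero_grad()
    emb(ids).sum().backward()
    optimizer.step()

after = emb.weight.detach()
print(torch.equal(before[:10], after[:10]))  # False: trainable rows were updated
print(torch.equal(before[10:], after[10:]))  # True: frozen rows are unchanged
```

The forward values are identical to self.weight; the detach() term only cuts the gradient path for the frozen rows. One caveat worth checking: optimizers with decoupled weight decay (for example torch.optim.AdamW with weight_decay > 0) still shrink rows whose gradient is zero, so frozen rows can drift unless weight decay is disabled for this parameter.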

I also replaced the BERT model's word embedding with my custom P_Embedding and froze the rest of the BERT weights like this
(I don't know whether this is correct):

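from transformers import BertForMaskedLM

class BertForPtuningLM(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)  # sets self.config and builds self.bert and the MLM head (self.cls)
        # swap in the partially trainable embedding; `device` is defined elsewhere in the script
        self.bert.embeddings.word_embeddings = P_Embedding(config.vocab_size, config.hidden_size, device=device)
        # freeze every BERT weight...
        for param in self.bert.parameters():
            param.requires_grad = False
        # ...except the word embedding
        self.bert.embeddings.word_embeddings.weight.requires_grad = True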
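The freezing step above is easy to check on a small model before running the full Trainer. Below is a minimal sketch in plain PyTorch, where the hypothetical TinyEncoder stands in for BertModel: freeze everything, re-enable only the embedding weight, then list what is still trainable. On a real BertForMaskedLM, the same listing would also show whether anything outside self.bert (the MLM head lives under cls) is still trainable, because the loop above only walks self.bert's parameters.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for BertModel: embedding, one hidden layer, output head."""

    def __init__(self, vocab_size: int = 20, hidden: int = 8) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden)
        self.layer = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.layer(self.word_embeddings(ids))))


model = TinyEncoder()

# Freeze every parameter, then re-enable only the word-embedding weight.
for param in model.parameters():
    param.requires_grad = False
model.word_embeddings.weight.requires_grad = True

# Listing the trainable parameters is a cheap check that freezing did what was intended.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['word_embeddings.weight']

# In a plain PyTorch loop, handing only trainable parameters to the optimizer
# keeps the frozen ones out of it entirely.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)
```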

I used the Hugging Face Trainer to train the model.
However, I get the following error when I run the code:

Traceback (most recent call last):
  File "", line 306, in <module>
  File "", line 304, in main
  File "/home/lzw/miniconda3/envs/Bert/lib/python3.8/site-packages/transformers/", line 1543, in train
    return inner_training_loop(
  File "/home/lzw/miniconda3/envs/Bert/lib/python3.8/site-packages/transformers/", line 1791, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/lzw/miniconda3/envs/Bert/lib/python3.8/site-packages/transformers/", line 2557, in training_step
  File "/home/lzw/miniconda3/envs/Bert/lib/python3.8/site-packages/torch/", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/lzw/miniconda3/envs/Bert/lib/python3.8/site-packages/torch/autograd/", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I'm not sure whether the problem comes from the newly defined weight…
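If the suspicion about new_weight is right, the error should reproduce without transformers or Trainer. Below is a minimal sketch in plain PyTorch, with a hypothetical CachedMaskedEmbedding that mirrors P_Embedding: the masked weight is computed once in __init__, so the weight -> new_weight graph is built only once. The first backward() frees the tensors that graph saved, and the second training step tries to backpropagate through it again.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedMaskedEmbedding(nn.Embedding):
    """Hypothetical minimal copy of P_Embedding: masked weight built once in __init__."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__(num_embeddings, embedding_dim)
        mask = torch.zeros(num_embeddings, 1)
        mask[0:10] = 1
        # This graph (weight -> new_weight) is created exactly once.
        self.new_weight = self.weight * mask + self.weight.detach() * (1 - mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.new_weight)


emb = CachedMaskedEmbedding(20, 4)
ids = torch.arange(20)

emb(ids).sum().backward()  # step 1 succeeds and frees the graph's saved tensors

error = ""
try:
    emb(ids).sum().backward()  # step 2 re-enters the already-freed weight -> new_weight graph
except RuntimeError as exc:
    error = str(exc)

print(error)
```

Passing retain_graph=True would hide the error without fixing the model: new_weight is a plain tensor attribute whose values are fixed at construction time, so optimizer updates to self.weight never reach forward(), and model.to(device) does not move it because it is neither a parameter nor a buffer. Rebuilding the masked weight inside forward() (with the mask registered as a buffer) avoids these problems.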