How to share weights between multiple encoders

I have multiple LSTM encoders like this:

class PromptEncoder(torch.nn.Module):
    """Produce trainable embeddings for a fixed set of prompt token ids.

    The encoder owns a private ``torch.nn.Embedding`` with ``length`` rows.
    On every forward pass the whole table is run through a 2-layer
    bidirectional LSTM and a two-layer MLP. Row ``i`` of the result is the
    embedding returned for ``prompt_ids[i]``.

    Args:
        name: Identifier stored on the instance (also used in error messages).
        length: Number of rows in the embedding table. Must be at least
            ``len(prompt_ids)`` so that every prompt id has a row.
        embedding_dim: Size of each output embedding.
        id_offset: Stored on the instance; not used by this class.
        init_embs: Accepted but not used by this class (presumably meant to
            initialise ``self.embedding`` -- TODO confirm).
        prompt_ids: Token ids handled by this encoder; ``prompt_ids[i]`` is
            served by row ``i`` of the computed weights.
        **kwargs: Ignored.

    Raises:
        ValueError: if ``length < len(prompt_ids)``.
    """

    def __init__(self,name, length,embedding_dim,id_offset, init_embs, prompt_ids,**kwargs) -> None:
        super().__init__()
        if len(prompt_ids) > length:
            # Row i serves prompt_ids[i]; without enough rows the lookup in
            # forward would index past the end of the computed weights.
            raise ValueError(
                f"length ({length}) must be >= len(prompt_ids) ({len(prompt_ids)})"
            )
        self.length = length
        self.name = name
        self.prompt_ids = prompt_ids
        self.embedding_dim = embedding_dim
        self.id_offset = id_offset
        # Integer lookup tensors are registered as buffers rather than frozen
        # Parameters. They still follow .to(device) and are saved in the
        # state_dict under the same keys, but they no longer appear in
        # .parameters() (and so are not handed to optimizers).
        self.register_buffer("input_ids", torch.tensor(prompt_ids))
        self.register_buffer("net_inps", torch.arange(length))
        self.embedding = torch.nn.Embedding(length, embedding_dim)
        # Each LSTM direction gets embedding_dim // 2 units, so the
        # bidirectional output has 2 * (embedding_dim // 2) features. That
        # equals embedding_dim only when embedding_dim is even, so the MLP
        # input is sized from the LSTM output rather than assumed.
        lstm_hidden = embedding_dim // 2
        self.lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden,
            num_layers=2,
            dropout=0,
            bidirectional=True,
            batch_first=True,
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * lstm_hidden, embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim, embedding_dim),
        )

    def get_prompt_token_fn(self):
        """Return a function mapping a token-id tensor to a boolean mask.

        The mask has the same shape as the input tensor and is True exactly
        where the id is one of this encoder's ``prompt_ids``. Those are the
        ids ``forward`` accepts.
        """
        def prompt_token_fn(token_ids):
            return (token_ids.unsqueeze(-1) == self.input_ids).any(dim=-1)

        return prompt_token_fn

    def forward(self,prompt_token_ids,pids=None):
        """Return one embedding per entry of ``prompt_token_ids``.

        Args:
            prompt_token_ids: Tensor of token ids, each of which must be one
                of this encoder's ``prompt_ids``. It is flattened, so the
                result has shape ``(prompt_token_ids.numel(), embedding_dim)``.
            pids: Unused.

        Raises:
            ValueError: if any id is not one of this encoder's prompt ids
                (previously such ids were silently mapped to row 0).
        """
        # Embed every row of the private table: (length, embedding_dim).
        embeds = self.embedding(self.net_inps)
        # Treat the rows as one sequence (batch of 1) so the LSTM mixes
        # information across prompt positions.
        lstm_out, _ = self.lstm(embeds.unsqueeze(0))
        emblog.info("lstml embeds: %s",embeds)

        # (1, length, 2*hidden) -> (length, embedding_dim)
        running_weight = self.mlp(lstm_out).squeeze(0)

        # matches[j, i] is True when the j-th requested id equals prompt_ids[i].
        matches = prompt_token_ids.reshape(-1, 1) == self.input_ids
        found = matches.any(dim=1)
        if not bool(found.all()):
            unknown = prompt_token_ids.reshape(-1)[~found].unique().tolist()
            raise ValueError(
                f"PromptEncoder {self.name!r} got token ids it does not handle: {unknown}"
            )
        row_idx = matches.int().argmax(dim=1)
        # Gather the computed weight row for each requested id.
        return running_weight[row_idx]


Some of the prompt_ids (i.e. input_ids) may be shared among the encoders. However, each encoder has its own embedding matrix. The encoders are all part of a container Module and are trained together. I want the shared ids to point to a shared embedding, so that when one encoder changes it, the change is reflected in all of them.
This is the forward wrapper:

    def forward(self,input_ids, labels, decoder_input_ids=None,pids=None,**kwargs):
        """Embed ``input_ids``, substitute prompt-encoder outputs, run the model.

        Positions flagged by ``self.prompt_token_fn`` are embedded by whichever
        encoder in ``self.prompt_encoders`` claims them. All other positions
        use ``self.model_embeddings``. The resulting ``inputs_embeds`` are
        passed to ``self.underlying_model``.

        Args:
            input_ids: Token ids for the batch.
            labels: Forwarded to ``underlying_model`` (assumes a Hugging Face
                seq2seq model such as T5 that accepts ``labels`` -- confirm).
            decoder_input_ids: Forwarded to ``underlying_model``.
            pids: Passed through to each prompt encoder.
            **kwargs: Forwarded to ``underlying_model``.

        Returns:
            Whatever ``underlying_model`` returns.
        """
        prompt_masks = self.prompt_token_fn(input_ids)
        if prompt_masks.any():
            input_ids_ = input_ids.clone()
            if self.replacing_token_id is not None:
                # Replace prompt ids with a token the base embedding can look
                # up; these positions are overwritten below anyway.
                input_ids_[prompt_masks] = self.replacing_token_id
            # Base-model embeddings for every non-prompt position.
            inputs_embeds = self.model_embeddings(input_ids_)
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype
            # NOTE: if an id is claimed by more than one encoder, the encoder
            # that comes last in self.prompt_encoders overwrites the others.
            for encoder in self.prompt_encoders:
                encoder_masks = encoder.get_prompt_token_fn()(input_ids)
                if not encoder_masks.any():
                    continue
                prompt_input_ids = input_ids[encoder_masks]
                prompt_embeds = encoder(prompt_input_ids, pids)
                # Match device and dtype so the masked in-place assignment
                # does not fail (e.g. fp16 base model vs fp32 encoder).
                inputs_embeds[encoder_masks] = prompt_embeds.to(device=device, dtype=dtype)
        else:
            inputs_embeds = self.model_embeddings(input_ids)

        # Return on both branches. Previously the prompt branch fell through
        # and returned None. labels / decoder_input_ids were accepted but
        # dropped; they are now forwarded so the seq2seq model can compute
        # its loss.
        return self.underlying_model(
            inputs_embeds=inputs_embeds,
            labels=labels,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )

The encoders are in a ModuleList. I have included the code showing how they are called in the forward method of the container Module. The container module actually wraps a transformer model (T5), which is frozen, and the outputs of the encoders are fed into it. I am somewhat of a beginner with PyTorch and Transformers. For example, are some parts of my LSTM encoders redundant, such as having both the input_ids parameter and net_inps?

Suppose the input ids for the whole model are [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, ...] in a batch. The ids of encoder 1 are [3, 4] and the ids of encoder 2 are [1, 3, 5], so 3 is common to both. Each encoder must update the embeddings for its corresponding inputs. However, 3 belongs to both of them. Maybe I should merge the results for 3 in the loop in forward, but I am not sure. I thought they might need to refer to a shared embedding space. Please advise on how to approach this problem.