Soft prompt learning for BERT and GPT using Transformers


Does the Transformers library have an easy way to fine-tune only the embeddings of a select few tokens in a Transformer model? (For example: the [unused1] [unused2] [unused3] tokens.)

I want to try to generate “soft prompts” without updating the entire embedding layer of the Transformer.

I suppose that changing the model myself is always an option, but I wonder what the easiest way to achieve this would be, and whether that is the best option. If possible, I would like to be able to use this with different types of models (e.g. BERT and GPT).


@FremyCompany Any updates on this? I am having the same inquiry!

@niklasstoehr Hi! As a matter of fact, I have written code that can achieve this, but it’s not 100% ready to use because I didn’t get very far in that specific experiment. Feel free to use it as a base for your work, though 🙂

```python
import torch
import torch.nn as nn

class PartialOverrideEmbedding(nn.Module):
    def __init__(self,
                 wte: nn.Embedding,
                 start_override: int = 110,  # [unused100] for my transformer
                 length_override: int = 800,  # [unused900] for my transformer
                 initialize_from_vocab: bool = True):
        """Wraps an embedding layer so that a contiguous range of token ids gets its own trainable embeddings.

        Args:
            wte (nn.Embedding): original transformer word embedding
            start_override (int, optional): first token id which will be trained separately. Defaults to 110 ([unused100] for BERT).
            length_override (int, optional): how many tokens, starting at start_override, are to be trained separately. Defaults to 800 ([unused900] for BERT).
            initialize_from_vocab (bool, optional): initializes the trainable embeddings from the default vocab. Defaults to True.
        """
        super(PartialOverrideEmbedding, self).__init__()
        self.start_override = start_override
        self.length_override = length_override
        self.wte = wte
        self.wte_override = nn.Embedding(
            length_override, wte.weight.shape[1]
        )
        if initialize_from_vocab:
            # In-place writes to a parameter that requires grad must happen under no_grad
            with torch.no_grad():
                self.wte_override.weight.copy_(
                    self.wte.weight[self.start_override:self.start_override + self.length_override]
                )
        self.initial_start_override = start_override
        self.initial_length_override = length_override

    def forward(self, tokens):
        """Run the forward pass.

        Args:
            tokens (torch.long): input token ids
        Returns:
            torch.float: embeddings of the tokens, taken from the trainable table for ids in the override range
        """
        # Detect which tokens are not in range for the override, and prepare masks for them
        mask_below = tokens < self.start_override
        mask_above = tokens >= self.start_override + self.length_override
        mask_out = torch.logical_or(mask_below, mask_above)

        embedded_tokens = self.wte(tokens)

        # Shift ids so that the override range maps to rows 0..length_override-1
        modified_tokens = tokens - self.start_override
        # Ids outside the range would be out of bounds: point them at row 0 (their result is discarded below)
        modified_tokens[mask_out] = 0
        # Look up the trainable embeddings
        embedded_tokens_after_override = self.wte_override(modified_tokens)

        # Keep the pretrained embedding for tokens outside the override range,
        # and use the trainable embedding for tokens inside it.
        return torch.where(mask_out.unsqueeze(-1), embedded_tokens, embedded_tokens_after_override)

    def commit_changes(self):
        # Write the trained embeddings back into the wrapped (original) embedding layer
        with torch.no_grad():
            self.wte.weight[self.initial_start_override:self.initial_start_override + self.initial_length_override] = self.wte_override.weight.detach().clone()
```

The gist of the trick is to:

  1. freeze the embedding layer of a pretrained model
  2. wrap that embedding layer in the one above
  3. replace the embedding layer of a pretrained model with the wrapped one
  4. train your model; only the embeddings in the wrapper will get trained
  5. commit the changes, so the wrapped embedding layer gets updated with the new values from the wrapper
  6. restore the wrapped embedding layer in the transformer before saving, by discarding the wrapper (you will not be able to save properly if your transformer still contains that wrapper embedding layer).
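Here is a self-contained sketch of the six steps on a toy model, so it runs without downloading anything. `RangeOverrideEmbedding` is a condensed restatement of the wrapper above (same lookup logic, shorter names) and `ToyModel` is a hypothetical stand-in for a transformer. On a Hugging Face model, steps 2, 3 and 6 would typically go through `model.get_input_embeddings()` and `model.set_input_embeddings(...)`.

```python
import torch
import torch.nn as nn


class RangeOverrideEmbedding(nn.Module):
    """Condensed restatement of the wrapper: ids in [start, start + length) use a
    separate trainable table, all other ids use the (frozen) wrapped embedding."""

    def __init__(self, wte: nn.Embedding, start: int, length: int):
        super().__init__()
        self.wte, self.start, self.length = wte, start, length
        self.override = nn.Embedding(length, wte.embedding_dim)
        with torch.no_grad():  # initialize from the original vocabulary
            self.override.weight.copy_(wte.weight[start:start + length])

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        in_range = (tokens >= self.start) & (tokens < self.start + self.length)
        shifted = torch.where(in_range, tokens - self.start, torch.zeros_like(tokens))
        return torch.where(in_range.unsqueeze(-1), self.override(shifted), self.wte(tokens))

    def commit_changes(self) -> None:
        with torch.no_grad():
            self.wte.weight[self.start:self.start + self.length] = self.override.weight


class ToyModel(nn.Module):
    """Hypothetical stand-in for a transformer: embedding followed by a linear head."""

    def __init__(self, vocab: int = 20, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(tokens))


torch.manual_seed(0)
model = ToyModel()
original = model.embed.weight.detach().clone()

# 1. freeze the pretrained model (embedding layer included)
for p in model.parameters():
    p.requires_grad_(False)
# 2. + 3. wrap the embedding layer and put the wrapper in its place
wrapper = RangeOverrideEmbedding(model.embed, start=5, length=3)
model.embed = wrapper
# 4. train: only the wrapper's override table requires grad
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.5)
tokens = torch.tensor([[1, 5, 6, 7, 12]])
for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(tokens).view(-1, 20), tokens.view(-1))
    loss.backward()
    opt.step()
# 5. commit the trained rows into the original embedding
wrapper.commit_changes()
# 6. discard the wrapper so state_dict() has the original keys again
model.embed = wrapper.wte

changed_rows = (model.embed.weight != original).any(dim=1).nonzero().flatten().tolist()
print(changed_rows)
```

Only the override ids (5, 6 and 7 here) should end up with new embeddings, even though ids 1 and 12 also appear in the training batch.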

If you need to update tokens that are not in a single contiguous range, you will have to nest multiple wrappers inside each other, which would require changes to the code above.
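A different technique (not the wrapper above) that handles arbitrary, non-contiguous ids without nesting is to keep the original embedding weight trainable and register a gradient hook that zeroes every row except the chosen ones. A minimal sketch, assuming a dense `nn.Embedding` (not `sparse=True`); `restrict_embedding_updates` is a hypothetical helper name. Weight decay must be 0 for that parameter, otherwise the optimizer would still shrink the "frozen" rows.

```python
import torch
import torch.nn as nn


def restrict_embedding_updates(embedding: nn.Embedding, trainable_ids: list[int]):
    """Zero the gradient of every embedding row except `trainable_ids`.
    Returns the hook handle so the restriction can be removed later."""
    mask = torch.zeros(embedding.num_embeddings, 1,
                       dtype=embedding.weight.dtype, device=embedding.weight.device)
    mask[trainable_ids] = 1.0
    embedding.weight.requires_grad_(True)
    # The hook runs before the gradient is accumulated into .grad
    return embedding.weight.register_hook(lambda grad: grad * mask)


torch.manual_seed(0)
embed = nn.Embedding(20, 8)
head = nn.Linear(8, 20)
head.requires_grad_(False)  # the rest of the model still has to be frozen separately
before = embed.weight.detach().clone()

handle = restrict_embedding_updates(embed, [2, 9, 15])  # non-contiguous ids
# weight_decay=0.0: decay would otherwise modify rows whose gradient is zero
opt = torch.optim.AdamW([embed.weight], lr=0.1, weight_decay=0.0)
tokens = torch.tensor([[2, 3, 9, 15, 4]])
for _ in range(3):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(head(embed(tokens)).view(-1, 20), tokens.view(-1))
    loss.backward()
    opt.step()
handle.remove()

changed = (embed.weight != before).any(dim=1).nonzero().flatten().tolist()
print(changed)
```

Ids 3 and 4 appear in the batch but are not in the trainable list, so their rows stay untouched. Unlike the wrapper, nothing needs to be committed or unwrapped before saving.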

If you use this to publish a paper on soft prompts, don’t hesitate to send me a copy, as I’m always curious about the topic! Feel free to add me in the acknowledgments if your paper has such a section, but you don’t have to; I license the above code as CC0.


For those who stumble upon this thread in the future: there is now a library from 🤗 Hugging Face that does this: huggingface/peft on GitHub (🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning).
