Does the Transformers library have an easy way to finetune only the embeddings of a select few tokens in a Transformer model? (For example: the [unused1] [unused2] [unused3] tokens.)
I want to try to generate “soft prompts” without updating the entire embedding layer of the Transformer.
I suppose that changing the model myself is always an option, but I wonder what the easiest way to achieve this would be, and whether that is the best option. I would also like to be able to use this with different types of models (e.g. BERT and GPT), if possible.
@niklasstoehr Hi! As a matter of fact, I have written code that can achieve this, but it’s not 100% polished because I didn’t get very far in that specific experiment. Feel free to use it as a base for your work, though.
import torch
import torch.nn as nn


class PartialOverrideEmbedding(nn.Module):
    def __init__(self,
                 wte: nn.Embedding,
                 start_override: int = 110,   # [unused100] for my transformer
                 length_override: int = 800,  # [unused900] for my transformer
                 initialize_from_vocab: bool = True):
        """Wrap a pretrained embedding and override a contiguous range of token ids
        with a separately trainable embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding
            start_override (int, optional): first token id which will be trained separately. Defaults to 110 ([unused100] for BERT).
            length_override (int, optional): how many tokens are to be trained separately after the first. Defaults to 800 ([unused900] for BERT).
            initialize_from_vocab (bool, optional): initializes the override rows from the default vocab. Defaults to True.
        """
        super().__init__()
        self.start_override = start_override
        self.length_override = length_override
        self.wte = wte
        self.wte_override = nn.Embedding(
            length_override, wte.weight.shape[1]
        )
        if initialize_from_vocab:
            # Copy under no_grad, since the override weight is a leaf that requires grad
            with torch.no_grad():
                self.wte_override.weight[:] = self.wte.weight[
                    self.start_override:self.start_override + self.length_override
                ]
        self.initial_start_override = start_override
        self.initial_length_override = length_override

    def forward(self, tokens):
        """Run the forward pass.

        Args:
            tokens (torch.long): input token ids before encoding

        Returns:
            torch.float: embeddings where tokens inside the override range come from the
                trainable embedding and all other tokens from the pretrained embedding
        """
        # Detect which tokens are outside the override range, and prepare masks for them
        mask_below = (tokens < self.start_override)
        mask_above = (tokens >= self.start_override + self.length_override)
        mask_out = torch.logical_or(mask_below, mask_above)
        embedded_tokens = self.wte(tokens)
        # Every token inside the override range has to be shifted into [0, length_override)
        modified_tokens = tokens - self.start_override
        # Zero out the ones which already have a pretrained embedding
        modified_tokens[mask_out] = 0
        # Get the trainable embeddings for the (shifted) override-range tokens
        embedded_tokens_after_override = self.wte_override(modified_tokens)
        # And finally put the pretrained embeddings back for the out-of-range tokens,
        # so only the override range remains trainable.
        embedded_tokens_after_override[mask_out] = embedded_tokens[mask_out]
        return embedded_tokens_after_override

    def commit_changes(self):
        with torch.no_grad():
            self.wte.weight[self.initial_start_override:self.initial_start_override + self.initial_length_override] = \
                self.wte_override.weight.detach().clone()
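As a quick sanity check, sketched with a stand-in nn.Embedding rather than a real checkpoint (the 1000×16 sizes are arbitrary), you can verify that only the override rows receive gradients:

# reusing the torch / nn imports from above
emb = nn.Embedding(1000, 16)      # stand-in for a pretrained word embedding
emb.weight.requires_grad_(False)  # freeze the "pretrained" weights
wrapped = PartialOverrideEmbedding(emb, start_override=110, length_override=800)

tokens = torch.tensor([[5, 112, 250, 950]])  # mix of in-range and out-of-range ids
wrapped(tokens).sum().backward()

print(emb.weight.grad)                                   # None: pretrained embedding untouched
print(wrapped.wte_override.weight.grad.abs().sum() > 0)  # tensor(True): only override rows train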
The gist of the trick is to (a rough end-to-end sketch follows the list):

1. freeze the embedding layer of a pretrained model
2. wrap that embedding layer in the one above
3. replace the embedding layer of the pretrained model with the wrapped one
4. train your model; only the embeddings in the wrapper will get trained
5. commit the changes, so the wrapped embedding gets updated with the new values from the wrapper
6. restore the original embedding layer in the transformer before saving, by discarding the wrapper (you will not be able to save properly if your transformer still contains the wrapper embedding layer)
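For concreteness, here is a rough sketch of those steps (hypothetical code: bert-base-uncased and the output path are just placeholders, get_input_embeddings / set_input_embeddings are the standard Transformers accessors, and the training loop itself is up to you):

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# 1. Freeze every pretrained parameter, including the original embedding layer
for param in model.parameters():
    param.requires_grad = False

# 2./3. Wrap the word embedding layer and put the wrapper back into the model
original_wte = model.get_input_embeddings()
wrapper = PartialOverrideEmbedding(original_wte, start_override=110, length_override=800)
model.set_input_embeddings(wrapper)

# 4. Train as usual; only wrapper.wte_override receives gradients
# ... your training loop goes here ...

# 5. Copy the trained rows back into the pretrained embedding matrix
wrapper.commit_changes()

# 6. Restore the plain embedding layer before saving, discarding the wrapper
model.set_input_embeddings(original_wte)
model.save_pretrained("my-soft-prompt-model")  # placeholder output directory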
If you need to update tokens that do not fall into a single contiguous range, you would have to nest multiple wrappers inside each other, and that would require changes to the code above.
If you use this to publish a paper on soft prompts, don’t hesitate to send me a copy, as I’m always curious about the topic! Feel free to add me to the acknowledgments if your paper has such a section, but you don’t have to; I license the above code as CC0.