Custom embedding / prompt tuning

I’m trying to add learnable prompts to the embedding layer of a pre-trained T5 model. My naive attempt is to subclass the T5ForConditionalGeneration module and then adjust the input embeddings in the forward method. This doesn’t throw any errors, but the prompts never learn: when I inspect model.swe.grad after calling backward(), it is always None. Here is my code:

import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

class myModel(T5ForConditionalGeneration):

    def __init__(self, n_tokens,
                       wte,
                       random_range: float = 0.5,
                       *args, **kwargs):
        super().__init__(*args, **kwargs)
        # frozen copy of the pre-trained input embedding layer
        self.wte = wte
        self.wte.weight.requires_grad = False

        # learnable soft-prompt embeddings, initialized uniformly at random
        self.n_tokens = n_tokens
        self.swe = nn.Parameter(torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range),
                                requires_grad=True)

    def forward(self, x, y):
        # embed the input ids, then prepend the soft prompts to every sequence
        input_embedding = wte(x)
        beg = torch.cat([swe.repeat(input_embedding.size(0), 1, 1), input_embedding], 1)

        return super().forward(inputs_embeds=beg, labels=y)

# load models
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model_base = T5ForConditionalGeneration.from_pretrained('t5-small')

n_tokens = 20
wte = model_base.get_input_embeddings()

# custom model
model = myModel(n_tokens=n_tokens, wte=wte, config=model_base.config)

inputs = tokenizer("This is my sample ", return_tensors="pt")
labels = tokenizer("sentence", return_tensors="pt")

loss = model(inputs.input_ids, labels.input_ids).loss
loss.backward()

print(model.swe.grad)
>>None
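
For reference, here is a minimal toy sketch of what I expected to happen (a plain nn.Module, nothing T5-specific): a Parameter registered on the module and used in forward should end up with a populated .grad after backward().

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # a registered learnable parameter, analogous to swe above
        self.p = nn.Parameter(torch.randn(3))

    def forward(self, x):
        # the parameter participates in the computation graph
        return (self.p * x).sum()

toy = Toy()
toy(torch.ones(3)).backward()
print(toy.p.grad)  # tensor([1., 1., 1.]) -- not None

So I’m not sure what is different about my T5 subclass.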

Any ideas how I can get gradients flowing to the prompts so that I can learn them?
