Decoding Modified Sentence Embeddings

Hi all! I’m attempting to perform operations on the semantic meaning of a sentence by transforming hidden-layer embeddings retrieved from a pretrained T5 model. However, I’m having difficulty decoding these embeddings once they’ve been modified. I’ve tried to write code that treats T5’s encoder and decoder as decoupled components, but I haven’t been able to figure out how to make the decoder generate text conditioned on the modified embedding.

Based on this post, my understanding is that I’d need to re-implement beam search or top-k/top-p sampling myself in order to decode from a modified encoder output. Is that the case? How would I go about doing so?

Here’s the code I’ve pieced together so far, based on the previously linked post:

import torch
import transformers
from transformers import top_k_top_p_filtering  # sampling helper used in the linked post

t5_model = transformers.T5ForConditionalGeneration.from_pretrained("t5-large")
t5_tok = transformers.T5Tokenizer.from_pretrained("t5-large")

text = "This is some example text."
tokenized = t5_tok(text, return_tensors="pt")
input_ids = tokenized.input_ids
attn_mask = tokenized.attention_mask

# Run only the encoder to get the hidden states I want to modify.
encoder_output_vectors = t5_model.encoder(input_ids=input_ids, attention_mask=attn_mask, return_dict=True).last_hidden_state

# Placeholder for whatever transformation gets applied to the embedding.
encoder_output_vectors = some_vector_operation(encoder_output_vectors)

# T5 starts decoding from the pad token and stops at </s>.
generated_so_far = torch.tensor([[t5_model.config.decoder_start_token_id]])
eos_token_id = t5_tok.eos_token_id
top_k, top_p, max_length = 50, 0.95, 64

while generated_so_far[0, -1].item() != eos_token_id and generated_so_far.shape[1] < max_length:
    decoder_outputs = t5_model.decoder(input_ids=generated_so_far,
                                       encoder_hidden_states=encoder_output_vectors,
                                       encoder_attention_mask=attn_mask,
                                       return_dict=True)
    # The decoder returns hidden states, so project them through the LM head to get vocabulary logits.
    sequence_output = decoder_outputs.last_hidden_state * (t5_model.model_dim ** -0.5)  # T5 rescales before the LM head when embeddings are tied
    next_token_logits = t5_model.lm_head(sequence_output)[:, -1, :]
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = torch.multinomial(torch.nn.functional.softmax(filtered_logits, dim=-1), num_samples=1)
    generated_so_far = torch.cat((generated_so_far, next_token), dim=1)

decoded_text = t5_tok.decode(generated_so_far[0], skip_special_tokens=True)

I’ve added a t5_tok.decode call at the very end to turn the generated ids back into text, but I’m not sure whether that’s the right place for it.
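
Relatedly, to avoid re-implementing beam search myself, I’ve been wondering whether I could just wrap the modified encoder states and hand them straight to generate(). The sketch below is what I have in mind; I’m assuming here that generate() skips re-running the encoder when it’s given a pre-computed encoder_outputs, but I haven’t confirmed that this is the intended usage:

from transformers.modeling_outputs import BaseModelOutput

# Wrap the modified hidden states so that (assuming my reading of generate() is right)
# they are used in place of a fresh encoder pass over the input text.
modified_encoder_outputs = BaseModelOutput(last_hidden_state=encoder_output_vectors)
generated_ids = t5_model.generate(encoder_outputs=modified_encoder_outputs,
                                  attention_mask=attn_mask,
                                  max_length=64,
                                  num_beams=4)
print(t5_tok.decode(generated_ids[0], skip_special_tokens=True))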

Some examples of what I mean by “operations” on sentence embeddings: interpolating between two embeddings, finding an orthogonal sentence embedding, sampling a point from a small ball around an embedding, etc. The model doesn’t need to be T5; any pretrained encoder-decoder language model would do.
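
For the interpolation case in particular, I had something like this rough sketch in mind (interpolate_embeddings and the two input names are placeholders of mine, and it assumes both sentences are padded to the same length so the tensors line up):

def interpolate_embeddings(emb_a, emb_b, alpha=0.5):
    # emb_a, emb_b: encoder outputs of shape (batch, seq_len, d_model),
    # assumed to be padded to the same seq_len so elementwise mixing is valid.
    return alpha * emb_a + (1.0 - alpha) * emb_b

# e.g. encoder_output_vectors = interpolate_embeddings(embedding_sentence_a, embedding_sentence_b)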

Many thanks for your support!
