Hi all! I’m attempting to perform operations on the semantic meaning of a sentence by transforming hidden layer embeddings that I’m retrieving from a pretrained T5 model. However, I’m having difficulty decoding these embeddings once they’ve been modified. I’ve tried to write code that treats T5’s encoder and decoder as decoupled, but I haven’t yet been able to figure out how to generate text with the decoder that is conditioned on the modified embedding.
Based on this post, my understanding is that I’d need to re-implement beam search or top-k filtering myself in order to decode a sentence embedding. Is that the case? If so, how would I go about it?
Here’s some pseudocode I’ve tried to write so far, based on the previously linked post:
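import torch
import transformers
from transformers.modeling_outputs import BaseModelOutput

t5_model = transformers.T5ForConditionalGeneration.from_pretrained("t5-large").eval()
t5_tok = transformers.T5Tokenizer.from_pretrained("t5-large")
text = "This is some example text."
tokenized = t5_tok(text, return_tensors="pt")
input_ids = tokenized.input_ids
attn_mask = tokenized.attention_mask
encoder_output_vectors = t5_model.encoder(input_ids, attention_mask=attn_mask, return_dict=True).last_hidden_state
# Placeholder for the embedding operation; it must keep the shape (batch, seq_len, d_model)
encoder_output_vectors = some_vector_operation(encoder_output_vectors)
# Work with token ids rather than strings; T5 starts decoding from the pad token
start_token_id = t5_model.config.decoder_start_token_id
eos_token_id = t5_model.config.eos_token_id
generated_so_far = torch.tensor([[start_token_id]])
top_k, max_new_tokens = 50, 64  # example values
with torch.no_grad():
    for _ in range(max_new_tokens):
        # Call the full model rather than t5_model.decoder so that lm_head is applied and we get vocabulary logits
        outputs = t5_model(encoder_outputs=BaseModelOutput(last_hidden_state=encoder_output_vectors), attention_mask=attn_mask, decoder_input_ids=generated_so_far)
        next_token_logits = outputs.logits[:, -1, :]
        # Top-k sampling done by hand (top-k only; top_k_top_p_filtering may not be available in recent transformers releases)
        top_values, top_indices = torch.topk(next_token_logits, top_k, dim=-1)
        choice = torch.multinomial(torch.softmax(top_values, dim=-1), num_samples=1)
        next_token = top_indices.gather(-1, choice)
        generated_so_far = torch.cat((generated_so_far, next_token), dim=1)
        if next_token.item() == eos_token_id:
            break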
The above should be using t5_tok.decode somewhere as well, but I’m not sure exactly where.
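A possible alternative to hand-rolling the loop, sketched under assumptions: as far as I can tell, generate() on Hugging Face encoder-decoder models accepts precomputed encoder_outputs, in which case beam search and sampling would not need to be re-implemented. To keep the sketch runnable without downloading weights, it uses a tiny randomly initialised T5 (the config sizes and token ids are made up for illustration, and the added noise is only a stand-in for some_vector_operation), so the generated ids are meaningless here.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# Tiny randomly initialised T5 (illustrative assumption); in practice this would be
# T5ForConditionalGeneration.from_pretrained("t5-large").
config = T5Config(vocab_size=128, d_model=32, d_kv=8, d_ff=64,
                  num_layers=2, num_heads=4, decoder_start_token_id=0)
torch.manual_seed(0)
model = T5ForConditionalGeneration(config).eval()

# Made-up token ids standing in for a tokenized sentence (1 is T5's </s> id).
input_ids = torch.tensor([[5, 17, 42, 9, 1]])
attn_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Run the encoder on its own and modify its output.
    hidden = model.encoder(input_ids=input_ids, attention_mask=attn_mask).last_hidden_state
    hidden = hidden + 0.01 * torch.randn_like(hidden)  # stand-in for some_vector_operation

    # Hand the modified states to generate(); the encoder is not run again.
    generated = model.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=hidden),
        attention_mask=attn_mask,
        num_beams=2,
        max_new_tokens=8,
    )

print(generated.shape)  # (batch, generated_length)
```

With a real checkpoint, t5_tok.decode(generated[0], skip_special_tokens=True) should turn the ids back into text, and sampling options such as do_sample=True, top_k and top_p can be passed to generate() in the same way as num_beams.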
Some examples of what I mean by “operations” on sentence embeddings: interpolating between two embeddings, finding an orthogonal sentence embedding, selecting from a random ball around an embedding, etc. The model doesn’t need to be T5; any pretrained encoder-decoder language model will do.
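To make those operations concrete, here is a minimal sketch in plain PyTorch. The function names (interpolate, orthogonal_to, sample_in_ball) are hypothetical, and it assumes each "embedding" is a tensor of encoder states with shape (batch, seq_len, d_model), treated as one flat vector for dot products and norms. Interpolating two sentences this way also assumes both have the same sequence length.

```python
import torch

def interpolate(a, b, t):
    """Linear interpolation: t=0 returns a, t=1 returns b."""
    return torch.lerp(a, b, t)

def orthogonal_to(a, generator=None):
    """Random tensor shaped like a, with the same norm, orthogonal to a
    when both are flattened into vectors."""
    r = torch.randn(a.shape, generator=generator, dtype=a.dtype)
    flat_a, flat_r = a.flatten(), r.flatten()
    # Remove the component of r that lies along a (one Gram-Schmidt step).
    flat_r = flat_r - (flat_r @ flat_a) / (flat_a @ flat_a) * flat_a
    # Rescale so the result has the same length as a.
    flat_r = flat_r * (flat_a.norm() / flat_r.norm())
    return flat_r.view_as(a)

def sample_in_ball(a, radius, generator=None):
    """Point drawn uniformly from the ball of the given radius around a."""
    n = a.numel()
    direction = torch.randn(a.shape, generator=generator, dtype=a.dtype)
    direction = direction / direction.norm()
    # u ** (1/n) makes the distance from the centre uniform in volume.
    u = torch.rand(1, generator=generator, dtype=a.dtype)
    return a + radius * u.pow(1.0 / n) * direction

# Random tensors standing in for two encoder outputs of equal shape.
g = torch.Generator().manual_seed(0)
emb_a = torch.randn(1, 5, 8, generator=g)
emb_b = torch.randn(1, 5, 8, generator=g)

mid = interpolate(emb_a, emb_b, 0.5)
orth = orthogonal_to(emb_a, generator=g)
near = sample_in_ball(emb_a, radius=0.1, generator=g)
```

Each result keeps the shape of the input, so it could be passed to the decoder in place of the original encoder output.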
Many thanks for your support!