I am using the GPT2LMHeadModel model but want to skip its embedding layer, and I will also be using the model.generate function for a text generation task. Is there any way to do this?
My reason for skipping the embedding layer is that I want to use embedding vectors from some other model.
Hey @dharmendra Soft prompting was added very recently! You can read about it in more detail in this GitHub comment, but TL;DR:
- Install transformers from main, as this feature will only be released in transformers v4.27 (`pip install --upgrade git+https://github.com/huggingface/transformers.git`)
- You can now pass `inputs_embeds` to `.generate()`, which will be used as the prompt.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Hello world"
input_ids = tokenizer.encode(text, return_tensors="pt")
# Traditional way of generating text
outputs = model.generate(input_ids)
print("\ngenerate + input_ids:", tokenizer.decode(outputs[0], skip_special_tokens=True))
# From inputs_embeds -- exact same output if you also pass `input_ids`. If you don't
# pass `input_ids`, you will get the same generated content but without the prompt
inputs_embeds = model.transformer.wte(input_ids)
outputs = model.generate(input_ids, inputs_embeds=inputs_embeds)
print("\ngenerate + inputs_embeds:", tokenizer.decode(outputs[0], skip_special_tokens=True))
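To make the difference described in the comments above concrete, here is a quick check of both calls side by side (assuming the same `gpt2` setup as the snippet above and transformers >= 4.27):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello world", return_tensors="pt")
inputs_embeds = model.transformer.wte(input_ids)

# With input_ids: the returned sequence includes the prompt tokens
with_prompt = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=10)

# Without input_ids: only the newly generated tokens are returned
without_prompt = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)

print(tokenizer.decode(with_prompt[0], skip_special_tokens=True))
print(tokenizer.decode(without_prompt[0], skip_special_tokens=True))
```

With greedy decoding, the continuation tokens are identical in both cases; only the prompt prefix differs.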
I have transformers version 2.11.0, and I passed `inputs_embeds` as a parameter; it generated a sequence without any error, so I think it also works on transformers versions below the one you mentioned.
Thank you very much for the response, it saved a lot of my crucial time. Thanks again!
One more question: I don't want to use the `generate` function with a pre-trained model, but with my own custom model.
I tried to import the GenerationMixin class and use the generation function, but I get an import error.
Can you please suggest a way to deal with this issue?
@dharmendra If your model inherits from `PreTrainedModel`, it should work out of the box. If it doesn't work, please open an issue in transformers.
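As a minimal sketch of what "out of the box" means here: a custom model only needs a config class, a `forward` that returns logits, and (in recent transformers versions) explicit `GenerationMixin` inheritance to get `.generate()` for free. `TinyConfig` and `TinyCausalLM` below are made-up toy names, and the model has no attention, so its outputs are meaningless; the point is only that `.generate()` runs on a custom subclass:

```python
import torch
from torch import nn
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

class TinyConfig(PretrainedConfig):
    model_type = "tiny-causal"

    def __init__(self, vocab_size=10, hidden_size=8, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

class TinyCausalLM(PreTrainedModel, GenerationMixin):
    config_class = TinyConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        # Accept either token ids or precomputed embeddings, like GPT-2 does
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return CausalLMOutput(logits=self.lm_head(inputs_embeds))

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Minimal hook used by .generate() to build each forward pass
        return {"input_ids": input_ids}

model = TinyCausalLM(TinyConfig())
out = model.generate(torch.tensor([[1, 2]]), max_new_tokens=3, do_sample=False)
print(out.shape)  # prompt length 2 + 3 new tokens
```

If importing `GenerationMixin` fails in your environment, check the class name spelling (`GenerationMixin`, not `GenerationMixIn`) and your transformers version.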