How to extend model.generate() to accept additional parameters to be used by the forward of Llama

Hello everyone.

I’m working with Llama 3.1 8B; in this setup the LLM takes the concatenation of audio and text tokens as input to perform the ASR task. I have modified the Llama class in Transformers so that audio and text tokens are handled by separate LoRA modules. To distinguish the two, the forward method accepts an index parameter that marks the position of the first text token, so the input sequence can be sliced into its audio and text parts. For example:

outputs = self.llm(inputs_embeds=embeddings, labels=labels, index=index)
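
To make that concrete, here is a toy illustration of the slicing the modified forward performs (the LoRA routing is omitted and the shapes are made up):

```python
import torch

# Illustrative shapes: batch of 1, 10 positions (audio + text), hidden size 4096
embeddings = torch.randn(1, 10, 4096)
index = 6  # position of the first text token in the sequence

audio_embeds = embeddings[:, :index, :]  # audio part, shape (1, 6, 4096)
text_embeds = embeddings[:, index:, :]   # text part,  shape (1, 4, 4096)
# each part is then routed through its own LoRA modules before the Llama layers
```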

So far so good. The problem is that at inference time I generate the text with:

decoded_ids = self.llm.generate(inputs_embeds=embeddings, ...)

How can I modify the generate method such that the call to the LLM includes the “index” parameter?
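To make the question concrete, this is roughly what I would like to work (a sketch only: LlamaWithIndex is a made-up name, and I am assuming that extra model kwargs passed to generate() get handed to prepare_inputs_for_generation at every decoding step, which is how I read the GenerationMixin source):

```python
from transformers import LlamaForCausalLM

class LlamaWithIndex(LlamaForCausalLM):
    def prepare_inputs_for_generation(self, *args, index=None, **kwargs):
        # Build the usual inputs (input_ids / inputs_embeds, attention_mask,
        # past_key_values, ...) via the parent implementation.
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        # Re-attach "index" so every call to forward() receives it.
        model_inputs["index"] = index
        return model_inputs

# Then, at inference time:
# decoded_ids = self.llm.generate(inputs_embeds=embeddings, index=index, ...)
```

But I’m not sure this is the intended way to do it, or whether it still behaves correctly after the first decoding step, when only the newly generated token is fed to forward.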

Maybe @ArthurZ can help me out with this? Thank you!!
