Injecting multimodal embeddings into a language model breaks the `generate` function

I’m following an approach similar to the LLaVA model, where I project audio embeddings into the language model’s embedding space. I add two special tokens to the language model’s tokenizer and inject 300 audio embeddings between these special tokens before passing the new embedding sequence, attention mask, labels and position IDs to the forward function.
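
To make the injection concrete, here’s a stripped-down sketch of the splice for a single sample (simplified, made-up names like audio_features and audio_start_token_id; the full forward is in the model code further down):

# Sketch only: splice the projected audio embeddings between the two special tokens.
# input_ids / inputs_embeds / attention_mask are for one sample; audio_features is the
# (300, hidden_size) output of the projection layer.
start = (input_ids == audio_start_token_id).nonzero(as_tuple=True)[0].item()
end = (input_ids == audio_end_token_id).nonzero(as_tuple=True)[0].item()

new_embeds = torch.cat((
    inputs_embeds[:start + 1],   # everything up to and including <AUDIO_START>
    audio_features,              # 300 projected audio embeddings
    inputs_embeds[end:],         # <AUDIO_END> and everything after it
), dim=0)

new_mask = torch.cat((
    attention_mask[:start + 1],
    torch.ones(audio_features.size(0), device=attention_mask.device),
    attention_mask[end:],
), dim=0)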

Now, training works great for Mistral, Gemma, Llama and Qwen. However, all of them seem to break once the model is trained and I call the generate function as shown below. Even without injecting the audio embeddings (as in the snippet below), generation breaks completely as soon as I pass any inputs_embeds, and I get a different error for each model, so I assume I’m missing something fundamental here.

prompt = f"""<|im_start|>system
            You are a helpful assistant<|im_end|>
            <|im_start|>user
            query<|im_end|>
            <|im_start|>assistant"""

tokenizer_output = model.tokenizer(
    prompt,
    return_tensors='pt',
    truncation=True,
)

input_ids = tokenizer_output['input_ids'].to(model.device)
attention_mask = tokenizer_output['attention_mask'].to(model.device)
batch_size, sequence_length = input_ids.shape

max_seq_length = sequence_length

inputs_embeds = model.language_model.get_input_embeddings()(input_ids).to(dtype=torch.bfloat16)

# Build position IDs from the attention mask (cumulative count of attended positions, pad positions set to 1)
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill((attention_mask == 0), 1).long()

top_k = 2
top_p = None

generation_params = {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "position_ids": position_ids,
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.9,
}

if top_k is not None:
    generation_params["top_k"] = top_k
if top_p is not None:
    generation_params["top_p"] = top_p

output_ids = model.language_model.generate(**generation_params)
result = model.tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(result)

Here’s my model code and some error examples for different text models:

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig, MusicgenConfig, MusicgenForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

class CustomLlarkModel(nn.Module):
    def __init__(self, model_name, model_type, device, use_lora=True):
        super(CustomLlarkModel, self).__init__()

        self.model_type = model_type
        self.device = device
        self.target_sr = 32000

        self.musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
        musicgen_config = AutoConfig.from_pretrained("facebook/musicgen-small")
        self.musicgen_model = MusicgenForConditionalGeneration(musicgen_config).to(self.device)

        self.hidden_size = musicgen_config.decoder.hidden_size

        # Define special tokens for audio start and end
        self.AUDIO_START_TOKEN = "<AUDIO_START>"
        self.AUDIO_END_TOKEN = "<AUDIO_END>"

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.language_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto").to(self.device)

        if use_lora:
            # Configure LoRA
            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=64, lora_alpha=128, lora_dropout=0, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])

            self.language_model = get_peft_model(
                self.language_model,
                peft_config
            )

            self.language_model.gradient_checkpointing_enable()

        # Add special tokens to tokenizer
        self.tokenizer.add_tokens([self.AUDIO_START_TOKEN, self.AUDIO_END_TOKEN])
        self.language_model.resize_token_embeddings(len(self.tokenizer))

        # Get token ids for special tokens
        self.audio_start_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_START_TOKEN)
        self.audio_end_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_END_TOKEN)

        # Simple projection layer to map audio embeddings into the language model's input space
        self.audio_projection = nn.Linear(self.hidden_size, self.language_model.config.hidden_size).to(device)

        self.freeze_unwanted_components()

        if self.model_type == "llama":
          self.start_of_header_token = 128006
          self.eot_token = 128001
        elif self.model_type == "gemma":
          self.start_of_turn_token = 106
          self.eot_token = self.tokenizer.eos_token_id # 1
        elif self.model_type == "qwen2":
          self.start_of_turn_token = 151644
          self.eot_token = self.tokenizer.pad_token_id # <|endoftext|>

    def get_prompt(self, query, answer):
        if self.model_type == "llama":
            return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
            You are a helpful AI assistant, you're given audio encoded as a sequence of 300 tokens below and must transcribe it precisely.
            {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|eot_id|><|start_header_id|>user<|end_header_id|>
            {query}<|eot_id|>
            <|start_header_id|>assistant<|end_header_id|>
            {answer}<|eot_id|><|end_of_text|>"""
        elif self.model_type == "gemma":
            return f"""<bos><start_of_turn>user
            You're given audio encoded as a sequence of 300 tokens: {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN} {query}<end_of_turn>
            <start_of_turn>model
            {answer}<end_of_turn><eos>"""
        elif self.model_type == "qwen2":
            return f"""<|im_start|>system
            You are a helpful AI assistant, you're given the following audio encoded as a sequence of 300 tokens and must transcribe it precisely. {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|im_end|>
            <|im_start|>user
            {query}<|im_end|>
            <|im_start|>assistant
            {answer}<|im_end|><|endoftext|>"""

    def freeze_unwanted_components(self):
        for param in self.musicgen_model.parameters():
        # for param in self.musicgen_model.text_encoder.parameters():
            param.requires_grad = False
        
    def forward(self, **kwargs):
        audio = kwargs.get('audio')
        queries = kwargs.get('query')
        answers = kwargs.get('answer')

        inputs = self.musicgen_processor(audio=audio, sampling_rate=self.target_sr, return_tensors="pt", padding=True)
        input_values_tensor = inputs["input_values"].to(self.device)
        padding_mask_tensor = inputs["padding_mask"].to(self.device)

        with torch.no_grad():
            # Get the encoder outputs for the batch
            encoder_outputs = self.musicgen_model.audio_encoder.encode(input_values_tensor, padding_mask_tensor)
            decoded_representation = self.musicgen_model.decoder(encoder_outputs.audio_codes, output_hidden_states=True)

        # Take the hidden states from the decoder's last layer
        all_hidden_states = decoded_representation.hidden_states[-1]

        # Downsample to 10Hz (one embedding per 100ms)
        # Assuming original timestep is 20ms, we need one out of every 5 embeddings
        downsample_rate = 5  # 100ms / 20ms = 5
        downsampled_hidden_states = all_hidden_states[:, ::downsample_rate, :]
        number_of_audio_tokens = downsampled_hidden_states.shape[1]

        batch_size = len(queries)
        # print(f"Input audio embedding shape: {downsampled_hidden_states.shape}")

        # Project audio embeddings to the language model's hidden size and cast to its compute dtype
        audio_features = self.audio_projection(downsampled_hidden_states).to(self.device, dtype=torch.bfloat16)
        # print(f"Projected audio shape: {projected_audio.shape}")

        prompts = [self.get_prompt(query, answer) for query, answer in zip(queries, answers)]

        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Tokenize the batch of prompts (right-padded to the longest prompt)
        tokenizer_output = self.tokenizer(
            prompts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            # max_length=max_length_without_audio,
            # pad_to_multiple_of=max_length_without_audio
        )

        # Get the tokenized input ids and attention mask
        input_ids = tokenizer_output['input_ids'].to(self.device)
        attention_mask = tokenizer_output['attention_mask'].to(self.device)
        
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # print(f"Initial input embeddings shape: {inputs_embeds.shape}")

        batch_size, sequence_length = input_ids.shape
        max_seq_length = sequence_length + number_of_audio_tokens

        # Buffers for the extended sequence (text + audio), kept in the same dtype as the text embeddings
        new_inputs_embeds = torch.zeros(
            (batch_size, max_seq_length, inputs_embeds.size(2)),
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device
        )
        new_attention_mask = torch.zeros((batch_size, max_seq_length), device=inputs_embeds.device)

        # Insert the projected audio embeddings between the AUDIO_START_TOKEN and AUDIO_END_TOKEN for each sample in the batch
        for i in range(batch_size):
            start_pos = (input_ids[i] == self.audio_start_token_id).nonzero(as_tuple=True)[0]
            end_pos = (input_ids[i] == self.audio_end_token_id).nonzero(as_tuple=True)[0]

            if len(start_pos) != 1:
                raise ValueError(f"Incorrect number of audio start tokens in the input. Got {len(start_pos)} tokens")

            if len(end_pos) != 1:
                raise ValueError(f"Incorrect number of audio end tokens in the input. Got {len(end_pos)} tokens")

            if start_pos.size(0) > 0 and end_pos.size(0) > 0:
                start_pos = start_pos[0].item()
                end_pos = end_pos[0].item()

                # Create the new embedding sequence
                part1 = inputs_embeds[i, :start_pos + 1]
                part2 = audio_features[i]
                part3 = inputs_embeds[i, end_pos:]

                new_embed = torch.cat((part1, part2, part3), dim=0)

                # new_inputs_embeds[i] = new_embed[:max_seq_length]
                new_inputs_embeds[i] = new_embed

                # Adjust attention mask for the inserted audio embeddings
                new_attention_mask[i] = torch.cat((
                    attention_mask[i, :start_pos + 1],
                    torch.ones(number_of_audio_tokens, device=inputs_embeds.device),
                    attention_mask[i, end_pos:]
                ), dim=0)

        # Position IDs from the extended attention mask (cumulative count of attended positions, pad positions set to 1)
        position_ids = (new_attention_mask.cumsum(-1) - 1).masked_fill_((new_attention_mask == 0), 1).long()

        labels_list = []

        for i in range(input_ids.size(0)):
            # Create a copy of input_ids to serve as the base for labels
            sample_labels = torch.full_like(input_ids[i], -100)

            if self.model_type == "llama":
                # Get the third start_of_header_token which is the one where the assistant's response starts
                assistant_start_pos = (input_ids[i] == self.start_of_header_token).nonzero(as_tuple=True)[0][2].item()
            elif self.model_type == "gemma":
                # Get the second start_of_turn_token which is the one where the model's response starts then +1
                assistant_start_pos = (input_ids[i] == self.start_of_turn_token).nonzero(as_tuple=True)[0][1].item() + 1
            elif self.model_type == "qwen2":
                # Get the third start_of_turn_token which is the one where the model's response starts then +1
                assistant_start_pos = (input_ids[i] == self.start_of_turn_token).nonzero(as_tuple=True)[0][2].item() + 1

            # Position of the end-of-text token (assumes it occurs exactly once per sample)
            eot_pos = (input_ids[i] == self.eot_token).nonzero(as_tuple=True)[0].item()

            # Fill in the input_ids for the assistant response part
            sample_labels[assistant_start_pos:eot_pos] = input_ids[i, assistant_start_pos:eot_pos]

            # Find AUDIO_START_TOKEN and AUDIO_END_TOKEN positions
            audio_start_pos = (input_ids[i] == self.audio_start_token_id).nonzero(as_tuple=True)[0].item()
            audio_end_pos = (input_ids[i] == self.audio_end_token_id).nonzero(as_tuple=True)[0].item()

            # Insert number_of_audio_tokens ignore-index (-100) labels between AUDIO_START_TOKEN and AUDIO_END_TOKEN for each sample
            sample_labels = torch.cat((
                sample_labels[:audio_start_pos + 1],
                torch.full((number_of_audio_tokens,), -100, dtype=input_ids.dtype, device=input_ids.device),
                sample_labels[audio_start_pos + 1:audio_end_pos + 1],
                sample_labels[audio_end_pos + 1:]
            ), dim=0)

            # Verify the length after concatenation
            expected_length = input_ids.size(1) + number_of_audio_tokens
            if sample_labels.size(0) != expected_length:
                raise ValueError(f"Concatenation error: expected length {expected_length} but got {sample_labels.size(0)}")

            labels_list.append(sample_labels)

        labels = torch.stack(labels_list)

        print(f"Labels shape: {labels.shape}")
        print(f"Number of tokens used as labels: {(labels != -100).sum().item()}")

        assert max_seq_length == position_ids.size(1) == new_inputs_embeds.size(1) == labels.size(1) == new_attention_mask.size(1), "position_ids, new_inputs_embeds, new_labels and new_attention_mask must have the same sequence length equal to the max_seq_length"

        outputs = self.language_model(
            inputs_embeds=new_inputs_embeds,
            position_ids=position_ids,
            attention_mask=new_attention_mask,
            labels=labels,
            return_dict=True
        )

        # print(f"Model output keys: {outputs.keys()}")
        if 'loss' in outputs:
            print(f"Loss: {outputs.loss.item()}")

        return outputs