How to correctly inject a condition into a causal model

I am working with AutoModelForCausalLM from transformers and I want to finetune a pretrained model with a condition (for example, finetuning an LLM conditioned on image feature vectors). I think concatenating an additional embedding in front of the whole sequence of token embeddings should work, but it turns out that the condition is 'useless'.
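
Concretely, here is a rough sketch of what I have in mind by "concatenating an additional embedding in front of the sequence" (names like image_dim, image_vec and cond_proj are just placeholders for illustration, not part of my real pipeline):

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('facebook/opt-125m')
model = AutoModelForCausalLM.from_config(config).cuda()

hidden = model.get_input_embeddings().embedding_dim
image_dim = 512                                   # placeholder feature size
cond_proj = nn.Linear(image_dim, hidden).cuda()   # placeholder projection

input_ids = torch.tensor([[0, 0], [1, 1]], device='cuda')
attention_mask = torch.ones_like(input_ids)

# one condition vector per sample (placeholder: random "image" features)
image_vec = torch.randn(2, image_dim, device='cuda')
cond_embed = cond_proj(image_vec).unsqueeze(1)            # (batch, 1, hidden)

# token embeddings as usual
inputs_embeds = model.get_input_embeddings()(input_ids)   # (batch, ntoken, hidden)

# prepend the condition embedding and extend the attention mask by one position
inputs_embeds = torch.cat([cond_embed, inputs_embeds], dim=1)
attention_mask = torch.cat([torch.ones_like(attention_mask[:, :1]), attention_mask], dim=1)

output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
# output.logits is (batch, ntoken + 1, vocab); for the loss I would ignore the
# prepended position (e.g. give it label -100 so cross_entropy skips it)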

Here’s what I’ve done already with a very simple model on a very minimal dataset:

First of all, the dataset consists of only 2 sentences, each with 2 tokens: 0 0 and 1 1, with no <bos>, no <eos> and no <pad>. What I want is to predict the second token from the first token (the condition).

Secondly, I trained the model on this dataset to predict only the second token given the first token:

input_ids = torch.tensor([
    [0, 0],
    [1, 1],
], device='cuda').long()

# convert tokens into embeddings
inputs_embeds = self.transformer.get_input_embeddings()(input_ids)

# prepare attention mask
attention_mask = torch.ones((2, 2), device='cuda')

output = self.transformer(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
)

# compute loss with the usual one-token shift: the logit at position i predicts the token at position i+1
logit = output.logits[:, :-1]
label = torch.tensor([[0], [1]], device='cuda')

loss = nn.functional.cross_entropy(
    logit.permute(0, 2, 1),
    label,
)
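
Just to spell out the shapes I am assuming here (with vocab_size=5 from my config further below):

# logits cover every position: (batch, ntoken, vocab)
assert output.logits.shape == (2, 2, 5)
# keep only the prediction made at the first position (it should predict the second token)
assert logit.shape == (2, 1, 5)
# cross_entropy expects (N, C, d1) logits against (N, d1) targets
assert logit.permute(0, 2, 1).shape == (2, 5, 1)
assert label.shape == (2, 1)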

The training loss rapidly dropped to around 1e-7, as expected.

Thirdly, I tried to modify the first token's embedding after getting the word embeddings:

# inject condition?
inputs_embeds[0, 0, :] = 0.
inputs_embeds[1, 0, :] = 1.

And the training loss gets stuck at 0.6~0.7 (roughly ln 2 ≈ 0.69, i.e. a 50/50 guess between the two tokens), and the predicted token is randomly 0 or 1. So the condition here seems to be useless.
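
To double-check whether the two injected conditions even lead to different predictions, I can drop something like this into train_one_step (just a debugging sketch):

# sanity check (sketch): do the two injected conditions give different logits?
with torch.no_grad():
    probe = self.transformer.get_input_embeddings()(input_ids).clone()
    probe[0, 0, :] = 0.
    probe[1, 0, :] = 1.
    probe_logits = self.transformer(
        inputs_embeds=probe,
        attention_mask=attention_mask,
    ).logits
    # the predictions at position 0 should differ between the two rows
    # if the condition actually influences the model
    print('probe', probe_logits[:, 0].argmax(dim=-1))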

How can I fix this code? Thanks for any hints!

Full code:

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        config = AutoConfig.from_pretrained(
            'facebook/opt-125m',
            max_position_embeddings=8192,  # OPT uses max_position_embeddings, not n_positions
            vocab_size=5,
            bos_token_id=2,
            eos_token_id=3,
            pad_token_id=4,
        )
        self.transformer = AutoModelForCausalLM.from_config(config=config)
    
    def forward(self):
        return self.train_one_step()
    
    def train_one_step(self) -> torch.Tensor:
        input_ids = torch.tensor([
            [0, 0],
            [1, 1],
        ], device='cuda').long()
        
        # convert tokens into embeddings
        inputs_embeds = self.transformer.get_input_embeddings()(input_ids)
        # shape (batch, ntoken, hidden)
        
        # inject condition
        inputs_embeds[0, 0, :] = 0.
        inputs_embeds[1, 0, :] = 1.

        # prepare attention mask
        attention_mask = torch.ones((2, 2), device='cuda')
        
        output = self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        
        # compute loss with the usual one-token shift: the logit at position i predicts the token at position i+1
        logit = output.logits[:, :-1]
        label = torch.tensor([[0], [1]], device='cuda')
        
        # debug info: targets vs. the model's argmax predictions
        print(label)
        print('max', torch.argmax(logit, dim=2))
        
        loss = nn.functional.cross_entropy(
            logit.permute(0, 2, 1),
            label,
        )
        
        return loss

model = MyModel().to('cuda')
# optimization boilerplate
epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001, pct_start=0.05, anneal_strategy="cos",
    div_factor=10.0, final_div_factor=1000.0,
    epochs=epochs,
    steps_per_epoch=1,  # one optimizer/scheduler step per epoch in the loop below
)

for i in range(epochs):
    loss = model.train_one_step()
    optimizer.zero_grad()
    print(loss)
    loss.backward()
    optimizer.step()
    scheduler.step()