Verify the correctness of my KTO implementation

I have tried to implement KTO based on my understanding of the paper. When I train a GPT-2 model on “argilla/ultrafeedback-binarized-preferences-cleaned-kto”, both the rewards and the loss keep fluctuating. Can you help me verify whether I have made any mistakes? Please note that the code I have written is a bare-minimum version of KTO.

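For context, the loop below assumes roughly the following setup. This is only a sketch; the model names, dataset split, and hyperparameter values here are placeholders rather than the exact ones I used.

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default

model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)      # policy being trained
ref_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)  # frozen reference model
ref_model.eval()

train_ds = load_dataset("argilla/ultrafeedback-binarized-preferences-cleaned-kto", split="train")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # placeholder learning rate
beta = 0.1        # placeholder KTO beta
epoch = 1
batch_size = 4
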
for _ in range(epoch):
    total_loss = 0
    total_rewards = 0 
    for i in range(0,len(train_ds),batch_size):
        
        batch = {
            "prompt": train_ds['prompt'][i:i+batch_size],
            "completion": train_ds['completion'][i:i+batch_size],
            "label": train_ds['label'][i:i+batch_size],
            # mismatched pairs for the KL term: rotate this batch's completions by one position
            "kl_data": train_ds['completion'][i:i+batch_size][1:] + train_ds['completion'][i:i+batch_size][:1]
        }

        tokenized_prompt = tokenizer(batch['prompt'], return_tensors='pt', padding=True, max_length=512, truncation=True)
        tokenized_completion = tokenizer([x+y for x, y in zip(batch['prompt'], batch['completion'])], return_tensors='pt', padding=True, max_length=512, truncation=True)
        tokenized_kl_data = tokenizer([x+y for x, y in zip(batch['prompt'], batch['kl_data'])], return_tensors='pt', padding=True, max_length=512, truncation=True)

        # loss masks: zeros over the padded-prompt columns, ones over the remaining (completion) columns
        concatenated_loss_mask = torch.cat((torch.zeros(tokenized_prompt['input_ids'].shape),
                                            torch.ones((tokenized_completion['input_ids'].shape[0],
                                                        tokenized_completion['input_ids'].shape[1] - tokenized_prompt['input_ids'].shape[1]))), dim=-1).to(device)

        concatenated_loss_mask_kl_data = torch.cat((torch.zeros(tokenized_prompt['input_ids'].shape),
                                                    torch.ones((tokenized_kl_data['input_ids'].shape[0],
                                                                tokenized_kl_data['input_ids'].shape[1] - tokenized_prompt['input_ids'].shape[1]))), dim=-1).to(device)

        # KL baseline (variant 1): full-distribution KL between policy and reference on the mismatched pairs
        input_ids = tokenized_kl_data['input_ids'].to(device)
        attention_mask = tokenized_kl_data['attention_mask'].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # log_probs = logits.log_softmax(dim=-1)[:,:-1,:] * concatenated_loss_mask_kl_data[:,1:].unsqueeze(-1) * attention_mask[:,1:].unsqueeze(-1)
        probs = logits.softmax(dim=-1)[:, :-1, :]
        with torch.no_grad():
            ref_logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
            ref_probs = ref_logits.softmax(dim=-1)[:, :-1, :]
            # ref_log_probs = ref_logits.log_softmax(dim=-1)[:,:-1,:] * concatenated_loss_mask_kl_data[:,1:].unsqueeze(-1) * attention_mask[:,1:].unsqueeze(-1)
            # selected_ref_log_probs = torch.gather(ref_log_probs, -1, input_ids[:,1:].unsqueeze(-1)).squeeze()

        # sum p * log(p / p_ref) over the vocabulary and the completion tokens,
        # average over the batch, detach, and floor at zero
        KL_baseline = torch.max(torch.tensor(0).to(device),
                                (probs * (torch.log((probs / ref_probs) + 1e-9)
                                          * concatenated_loss_mask_kl_data[:, 1:].unsqueeze(-1)
                                          * attention_mask[:, 1:].unsqueeze(-1))).sum(dim=-1).sum(dim=-1).mean().detach())

        # Rewards: policy vs. reference summed log-probs on the actual (prompt, completion) pairs
        input_ids = tokenized_completion['input_ids'].to(device)
        attention_mask = tokenized_completion['attention_mask'].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        log_probs = logits.log_softmax(dim=-1)[:, :-1, :] * concatenated_loss_mask[:, 1:].unsqueeze(-1) * attention_mask[:, 1:].unsqueeze(-1)
        selected_log_probs = torch.gather(log_probs, -1, input_ids[:, 1:].unsqueeze(-1)).squeeze()

        with torch.no_grad():
            ref_logits = ref_model(input_ids=input_ids,attention_mask=attention_mask).logits
            ref_log_probs = ref_logits.log_softmax(dim=-1)[:,:-1,:] * concatenated_loss_mask[:,1:].unsqueeze(-1)* attention_mask[:,1:].unsqueeze(-1)
            selected_ref_log_probs = torch.gather(ref_log_probs,-1,input_ids[:,1:].unsqueeze(-1)).squeeze()
        
        rewards = selected_log_probs.sum(dim=-1) - selected_ref_log_probs.sum(dim=-1)
        labels_tensor = torch.tensor(batch['label']).to(device)
        
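        # KTO loss from the paper, as I understand it:
        #   desirable   (label = True):  lambda_D * (1 - sigmoid(beta * (reward - KL_baseline)))
        #   undesirable (label = False): lambda_U * (1 - sigmoid(beta * (KL_baseline - reward)))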
        loss = torch.where(labels_tensor,
                           (1-torch.sigmoid(beta*(rewards-KL_baseline))),
                           (1-5*torch.sigmoid(beta*(KL_baseline-rewards)))).mean()
        total_loss+=loss.item()
        total_rewards+=rewards.mean().item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

I have tried implementing both types of KL divergence: one that computes the full per-token KL over the vocabulary and sums it, and another that uses only the log-probabilities of the selected tokens to estimate the KL, as is done in the Hugging Face (TRL) implementation of KTO. Please let me know what I should change.
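
For reference, the second variant I tried looked roughly like this, replacing the KL_baseline block above. This is only a sketch along the lines of the TRL approach, not the exact TRL code; the variable names are just how I wrote it.

        # KL estimate (variant 2): summed selected-token log-prob difference on the mismatched
        # (prompt, kl_data) pairs, averaged over the batch and floored at zero
        kl_input_ids = tokenized_kl_data['input_ids'].to(device)
        kl_attention_mask = tokenized_kl_data['attention_mask'].to(device)
        kl_mask = concatenated_loss_mask_kl_data[:, 1:] * kl_attention_mask[:, 1:]

        with torch.no_grad():
            policy_kl_log_probs = model(input_ids=kl_input_ids, attention_mask=kl_attention_mask).logits.log_softmax(dim=-1)[:, :-1, :]
            ref_kl_log_probs = ref_model(input_ids=kl_input_ids, attention_mask=kl_attention_mask).logits.log_softmax(dim=-1)[:, :-1, :]

            policy_selected = torch.gather(policy_kl_log_probs, -1, kl_input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
            ref_selected = torch.gather(ref_kl_log_probs, -1, kl_input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

            KL_baseline = ((policy_selected - ref_selected) * kl_mask).sum(dim=-1).mean().clamp(min=0)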

Thank you!
