I have tried to implement KTO based on my understanding of the paper. When I train a GPT-2 model on "argilla/ultrafeedback-binarized-preferences-cleaned-kto", both the rewards and the loss fluctuate. Can you help me verify whether I have made any mistakes? Please note that the code I have written is a bare-minimum KTO implementation.
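For reference, this is the per-example objective I am trying to reproduce. It is only my reading of the paper; the function and argument names below are my own placeholders, and lambda_d / lambda_u are the desirable/undesirable weights.

import torch

def kto_loss(policy_logratio, kl_baseline, is_desirable, beta=0.1, lambda_d=1.0, lambda_u=1.0):
    # policy_logratio: log pi_theta(y|x) - log pi_ref(y|x), summed over completion tokens, shape (batch,)
    # kl_baseline: detached, non-negative scalar estimate of KL(pi_theta || pi_ref)
    # is_desirable: bool tensor, True for desirable completions, False for undesirable ones
    v = torch.where(is_desirable,
                    lambda_d * torch.sigmoid(beta * (policy_logratio - kl_baseline)),
                    lambda_u * torch.sigmoid(beta * (kl_baseline - policy_logratio)))
    lam = torch.where(is_desirable,
                      torch.full_like(v, lambda_d),
                      torch.full_like(v, lambda_u))
    # Paper's loss: E[lambda_y - v(x, y)]
    return (lam - v).mean()

In terms of the variables in my loop below this would be called as kto_loss(rewards, KL_baseline, labels_tensor, beta=beta, lambda_u=5.0), although in the loop as written I only kept the factor 5 on the undesirable branch.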
for _ in range(epoch):
    total_loss = 0
    total_rewards = 0
    for i in range(0, len(train_ds), batch_size):
        # Build the batch; "kl_data" pairs each prompt with a completion shifted by
        # one position, i.e. a mismatched completion for the KL estimate.
        batch = {
            "prompt": train_ds['prompt'][i:i+batch_size],
            "completion": train_ds['completion'][i:i+batch_size],
            "label": train_ds['label'][i:i+batch_size],
            "kl_data": train_ds['completion'][i:i+batch_size][1:] + train_ds['completion'][i:i+batch_size][:1]
        }
        tokenized_prompt = tokenizer(batch['prompt'], return_tensors='pt', padding=True, max_length=512, truncation=True)
        tokenized_completion = tokenizer([x + y for x, y in zip(batch['prompt'], batch['completion'])],
                                         return_tensors='pt', padding=True, max_length=512, truncation=True)
        tokenized_kl_data = tokenizer([x + y for x, y in zip(batch['prompt'], batch['kl_data'])],
                                      return_tensors='pt', padding=True, max_length=512, truncation=True)

        # Loss masks: zeros over the prompt tokens, ones over the completion tokens.
        concatenated_loss_mask = torch.cat((torch.zeros(tokenized_prompt['input_ids'].shape),
                                            torch.ones((tokenized_completion['input_ids'].shape[0],
                                                        tokenized_completion['input_ids'].shape[1] - tokenized_prompt['input_ids'].shape[1]))), dim=-1).to(device)
        concatenated_loss_mask_kl_data = torch.cat((torch.zeros(tokenized_prompt['input_ids'].shape),
                                                    torch.ones((tokenized_kl_data['input_ids'].shape[0],
                                                                tokenized_kl_data['input_ids'].shape[1] - tokenized_prompt['input_ids'].shape[1]))), dim=-1).to(device)

        # KL baseline: forward passes on the mismatched (prompt, shifted completion) pairs.
        input_ids = tokenized_kl_data['input_ids'].to(device)
        attention_mask = tokenized_kl_data['attention_mask'].to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # log_probs = logits.log_softmax(dim=-1)[:, :-1, :] * concatenated_loss_mask_kl_data[:, 1:].unsqueeze(-1) * attention_mask[:, 1:].unsqueeze(-1)
        probs = logits.softmax(dim=-1)[:, :-1, :]
        with torch.no_grad():
            ref_logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
            ref_probs = ref_logits.softmax(dim=-1)[:, :-1, :]
            # ref_log_probs = ref_logits.log_softmax(dim=-1)[:, :-1, :] * concatenated_loss_mask_kl_data[:, 1:].unsqueeze(-1) * attention_mask[:, 1:].unsqueeze(-1)
            # selected_ref_log_probs = torch.gather(ref_log_probs, -1, input_ids[:, 1:].unsqueeze(-1)).squeeze()
        # Full-distribution per-token KL, summed over the completion, averaged over the batch, clamped at 0.
        KL_baseline = torch.max(torch.tensor(0).to(device),
                                (probs * (torch.log((probs / ref_probs) + 1e-9)
                                          * concatenated_loss_mask_kl_data[:, 1:].unsqueeze(-1)
                                          * attention_mask[:, 1:].unsqueeze(-1))).sum(dim=-1).sum(dim=-1).mean().detach())

        # Generate rewards: forward passes on the matched (prompt, completion) pairs.
        input_ids = tokenized_completion['input_ids'].to(device)
        attention_mask = tokenized_completion['attention_mask'].to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        log_probs = logits.log_softmax(dim=-1)[:, :-1, :] * concatenated_loss_mask[:, 1:].unsqueeze(-1) * attention_mask[:, 1:].unsqueeze(-1)
        selected_log_probs = torch.gather(log_probs, -1, input_ids[:, 1:].unsqueeze(-1)).squeeze()
        with torch.no_grad():
            ref_logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
            ref_log_probs = ref_logits.log_softmax(dim=-1)[:, :-1, :] * concatenated_loss_mask[:, 1:].unsqueeze(-1) * attention_mask[:, 1:].unsqueeze(-1)
            selected_ref_log_probs = torch.gather(ref_log_probs, -1, input_ids[:, 1:].unsqueeze(-1)).squeeze()

        # Implicit reward: summed log-ratio between policy and reference over the completion tokens.
        rewards = selected_log_probs.sum(dim=-1) - selected_ref_log_probs.sum(dim=-1)
        labels_tensor = torch.tensor(batch['label']).to(device)
        loss = torch.where(labels_tensor,
                           (1 - torch.sigmoid(beta * (rewards - KL_baseline))),
                           (1 - 5 * torch.sigmoid(beta * (KL_baseline - rewards)))).mean()

        total_loss += loss.item()
        total_rewards += rewards.mean().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
I have tried implementing both types of KL divergence: one that computes the per-token KL over the full vocabulary and sums it over the completion, and one that uses only the selected-token log-probabilities to estimate the KL, as in the Hugging Face implementation of KTO. Roughly what I mean by the second approach is sketched below. Please let me know what I should change.
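The sketch below is only my understanding of that second, sequence-level approach, built from the summed selected-token log-probs of the mismatched pairs; the function and variable names are my own placeholders.

import torch

def sequence_level_kl(policy_kl_logps, ref_kl_logps):
    # policy_kl_logps / ref_kl_logps: summed completion log-probs of the policy and
    # reference model on the mismatched (prompt, shuffled completion) pairs, shape (batch,)
    kl = (policy_kl_logps - ref_kl_logps).mean().detach()
    return torch.clamp(kl, min=0)  # keep the baseline non-negative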
Thank you!