I am using the DPR models provided by the Hugging Face Transformers library to fine-tune on a custom dataset:
# Load each DPR encoder together with its matching tokenizer; for each
# checkpoint the tokenizer is loaded first, then the model.
passage_tokenizer, passage_model = (
    loader.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    for loader in (DPRContextEncoderTokenizer, DPRContextEncoder)
)
query_tokenizer, query_model = (
    loader.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    for loader in (DPRQuestionEncoderTokenizer, DPRQuestionEncoder)
)
class Huggingface_DPR(nn.Module):
    """DPR-style dual encoder with a trainable projection head on each side.

    Every call to ``forward`` samples its own training mini-batch from
    ``questions`` / ``passage_dict`` (see ``batch_tokenize``), encodes it, and
    returns log-probabilities over each query's candidate passages together
    with the index of the positive candidate, ready for ``nn.NLLLoss``.

    Args:
        query_model: question encoder; called with ``input_ids`` /
            ``attention_mask`` and indexed with ``['pooler_output']``.
        passage_model: context encoder with the same calling contract.
        query_tokenizer: tokenizer applied to the list of query strings.
        passage_tokenizer: tokenizer applied to the list of passage strings.
        passage_dict: passage text indexed by integer id; ids are drawn from
            ``range(len(passage_dict))``, so its keys must cover that range.
        questions: indexable sequence of ``(query_text, passage_id)`` pairs.
        dense_size: output width of both projection heads.
        freeze_params: stored on the instance but not used by this class.
        batch_size: number of queries sampled per forward pass.
        sample_size: candidate passages per query (the positive plus randomly
            drawn passages).
        hidden_size: width of the encoders' ``pooler_output`` (768 for the
            ``*-base`` DPR checkpoints).
    """

    def __init__(self, query_model, passage_model, query_tokenizer, passage_tokenizer,
                 passage_dict, questions, dense_size, freeze_params=0.0, batch_size=2,
                 sample_size=4, hidden_size=768):
        super().__init__()
        self.query_model = query_model
        self.query_tokenizer = query_tokenizer
        self.passage_model = passage_model
        self.passage_tokenizer = passage_tokenizer
        self.freeze_params = freeze_params
        self.sample_size = sample_size
        self.batch_size = batch_size
        # NOTE(review): both heads end in ReLU -> Linear -> GELU. If the NLL loss
        # sits at ln(sample_size), every candidate is receiving the same score,
        # i.e. the embeddings have collapsed to a constant; a dead ReLU layer or
        # a too-large learning rate are the usual suspects — worth checking.
        self.passage_to_dense = self._projection_head(hidden_size, dense_size)
        self.query_to_dense = self._projection_head(hidden_size, dense_size)
        self.passage_dict = passage_dict
        self.query_tuple = questions
        self.log_softmax = nn.LogSoftmax(dim=1)

    @staticmethod
    def _projection_head(in_features, dense_size):
        """Build the MLP mapping an encoder's pooled output to ``dense_size``."""
        return nn.Sequential(
            nn.Linear(in_features, dense_size * 2),
            nn.ReLU(),
            nn.Linear(dense_size * 2, dense_size),
            nn.GELU(),
        )

    def batch_tokenize(self):
        """Sample a mini-batch with negatives and tokenize it.

        For each of ``batch_size`` randomly chosen questions, ``sample_size``
        passage ids are drawn at random; if the question's own passage is not
        among them it overwrites a random slot.

        Returns:
            passage_tensor: tokenizer output for ``batch_size * sample_size``
                passages in row-major order (all candidates of query 0, then
                query 1, ...).
            query_tensor: tokenizer output for the ``batch_size`` queries.
            true_idx: list of length ``batch_size``; entry ``b`` is the column
                of query ``b``'s positive passage among its candidates.
        """
        rand_idx = np.random.randint(0, len(self.passage_dict), (self.batch_size, self.sample_size))
        queries = []
        true_idx = []
        for row in rand_idx:  # each ``row`` is a view, so writes update rand_idx
            # randrange excludes its upper bound; random.randint(0, len(...))
            # is inclusive and could index one past the end of query_tuple.
            query, passage_id = self.query_tuple[random.randrange(len(self.query_tuple))]
            queries.append(query)
            hits = np.flatnonzero(row == passage_id)
            if hits.size:
                true_idx.append(int(hits[0]))
            else:
                idx = random.randint(0, self.sample_size - 1)
                row[idx] = passage_id
                true_idx.append(idx)
        passages = [self.passage_dict[pid] for pid in rand_idx.flat]
        # truncation=True keeps over-long passages within the encoder's limit
        # instead of crashing the forward pass.
        passage_tensor = self.passage_tokenizer(passages, padding='longest',
                                                truncation=True, return_tensors="pt")
        query_tensor = self.query_tokenizer(queries, padding='longest',
                                            truncation=True, return_tensors="pt")
        return passage_tensor, query_tensor, true_idx

    def dot_product(self, q_vector, p_vector):
        """Score every candidate against its query by dot product.

        Args:
            q_vector: query embeddings, ``(batch, dim)``.
            p_vector: candidate embeddings, ``(batch, candidates, dim)``.

        Returns:
            Tensor of shape ``(batch, 1, candidates)``.
        """
        q_vector = q_vector.unsqueeze(1)
        return torch.matmul(q_vector, torch.transpose(p_vector, -2, -1))

    def forward(self):
        """Sample, encode and score one mini-batch.

        Returns:
            log_scores: ``(batch_size, sample_size)`` log-softmax scores.
            target: ``(batch_size,)`` long tensor of positive-candidate
                indices, on the same device as ``log_scores``.
        """
        passage_tensor, query_tensor, true_idx = self.batch_tokenize()
        # Freshly tokenized tensors live on CPU; move them to the module's device.
        device = self.passage_to_dense[0].weight.device
        passage_tensor = passage_tensor.to(device)
        query_tensor = query_tensor.to(device)

        dense_passage = self.passage_model(
            input_ids=passage_tensor['input_ids'],
            attention_mask=passage_tensor['attention_mask'],
        )['pooler_output']
        dense_query = self.query_model(
            input_ids=query_tensor['input_ids'],
            attention_mask=query_tensor['attention_mask'],
        )['pooler_output']

        # Passages were tokenized row-major, so this regroups them per query.
        dense_passage = dense_passage.reshape(self.batch_size, self.sample_size, -1)
        dense_passage = self.passage_to_dense(dense_passage)
        dense_query = self.query_to_dense(dense_query)

        similarity_score = self.dot_product(dense_query, dense_passage).squeeze(1)
        log_scores = self.log_softmax(similarity_score)
        target = torch.tensor(true_idx, dtype=torch.long, device=log_scores.device)
        return log_scores, target
I am using the dot product as the similarity metric and negative log-likelihood as the loss function.
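In the `batch_tokenize` method, I implement negative sampling: each query gets `sample_size` candidate passages, one of which is its positive passage. `batch_tokenize` returns three things. The first is `passage_tensor`, holding the `sample_size` candidates for each of the `batch_size` queries; it is flattened to `(batch_size * sample_size, padded_length)`, i.e. logically `(batch_size, sample_size, padded_length)`. The second is `query_tensor`, of size `(batch_size, padded_length)`. The third is `true_idx`, a list of length `batch_size` giving the position of each query's positive passage among its candidates.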
This is the training loop that I am using.
## With Batch
# Train for 5 epochs of ``num_batches`` mini-batches each (batches are
# numbered from 1); ``dpr_model()`` samples a fresh mini-batch on every call.
num_batches = 99
log_every = 2
for epo in range(5):
    epoch_loss = 0
    sum_loss = 0
    for b in range(1, num_batches + 1):
        optimizer.zero_grad()
        pred, true_idx = dpr_model()
        loss = criterion(pred, true_idx)
        # Read the scalar once and reuse it for both running totals.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        sum_loss += batch_loss
        loss.backward()
        optimizer.step()
        if b % log_every == 0:
            # Mean loss over the last ``log_every`` batches.
            print(f"Epoch : {epo + 1} Batch : {b} Loss: {sum_loss/log_every}")
            sum_loss = 0
    # Mean loss over every batch run in this epoch.
    print(f"Epoch {epo + 1} : Loss : {epoch_loss/num_batches}")
Loss output:
Epoch : 1 Batch : 2 Loss: 2.083882689476013
Epoch : 1 Batch : 4 Loss: 2.078924059867859
Epoch : 1 Batch : 6 Loss: 2.0736374855041504
Epoch : 1 Batch : 8 Loss: 2.080022931098938
Epoch : 1 Batch : 10 Loss: 2.0795756578445435
Epoch : 1 Batch : 12 Loss: 2.084058165550232
Epoch : 1 Batch : 14 Loss: 2.079327940940857
Epoch : 1 Batch : 16 Loss: 2.0794405937194824
Epoch : 1 Batch : 18 Loss: 2.079430937767029
Epoch : 1 Batch : 20 Loss: 2.0794434547424316
Epoch : 1 Batch : 22 Loss: 2.0794490575790405
Epoch : 1 Batch : 24 Loss: 2.0794308185577393
Epoch : 1 Batch : 26 Loss: 2.0794389247894287
Epoch : 1 Batch : 28 Loss: 2.079446792602539
Epoch : 1 Batch : 30 Loss: 2.0794416666030884
Epoch : 1 Batch : 32 Loss: 2.0794419050216675
Epoch : 1 Batch : 34 Loss: 2.079440116882324
Epoch : 1 Batch : 36 Loss: 2.079442262649536
Epoch : 1 Batch : 38 Loss: 2.079442024230957
Epoch : 1 Batch : 40 Loss: 2.0794413089752197
Epoch : 1 Batch : 42 Loss: 2.079441547393799
Epoch : 1 Batch : 44 Loss: 2.0794419050216675
Epoch : 1 Batch : 46 Loss: 2.0794419050216675
...........................
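The loss does not decrease noticeably, even from the very beginning. It plateaus at ≈ 2.0794 ≈ ln 8, which is exactly the loss you get when all 8 candidates receive the same score. Is there a mistake in the way I am fine-tuning?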