I have a decoder that uses pre-trained BERT. Training looks fine, but I am having problems at inference time. I wonder whether my approach to greedy generation in the function below is correct.
import torch
import torch.nn as nn
from transformers import BertConfig, BertLMHeadModel
from transformers.models.bart.modeling_bart import shift_tokens_right

class TextDecoder(nn.Module):
    def __init__(self, tokenizer, config, device):
        super().__init__()
        self.device = device
        self.config = config
        self.config["add_cross_attention"] = True  # cross-attend over the encoder states
        self.config["is_decoder"] = True           # causal (left-to-right) self-attention
        self.tokenizer = tokenizer
        self.decoder_config = BertConfig.from_dict(self.config)
        self.decoder = BertLMHeadModel.from_pretrained('bert-base-uncased', config=self.decoder_config).to(self.device)
        self.decoder.config.decoder_start_token_id = tokenizer.cls_token_id
        self.decoder.config.pad_token_id = tokenizer.pad_token_id
    def forward(self, input_ids, attention_mask, enc_out, encoder_attention_mask):
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        encoder_attention_mask = encoder_attention_mask.to(self.device)
        # prepend the start token ([CLS]) and drop the last token
        decoder_input_ids = shift_tokens_right(input_ids, self.tokenizer.pad_token_id, self.tokenizer.cls_token_id)
        # decoder_attention_mask = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
        # BertLMHeadModel shifts the labels internally, so the shifted ids serve as both input and labels
        decoder_out = self.decoder(input_ids=decoder_input_ids, attention_mask=attention_mask,
                                   encoder_hidden_states=enc_out, encoder_attention_mask=encoder_attention_mask,
                                   labels=decoder_input_ids, return_dict=True)
        return decoder_out.loss, decoder_out.logits
    def generate(self, input_ids, encoder_hidden_states, do_sample=False, max_length=14):
        outputs = self.decoder.generate(input_ids=input_ids,
                                        encoder_hidden_states=encoder_hidden_states,
                                        do_sample=do_sample,
                                        max_length=max_length)
        # decode the generated ids back to text
        output_sen = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return output_sen
def train(data_loader):
    text_encoder.train()
    text_decoder.train()
    epoch_loss = 0
    for i, (in_ids, in_msk, out_ids, out_msk) in enumerate(data_loader):
        optimizer.zero_grad()
        encoder_out = text_encoder(in_ids, in_msk)
        loss_, logits = text_decoder(out_ids, out_msk, encoder_out, encoder_attention_mask=in_msk)
        loss_.backward()
        optimizer.step()
        epoch_loss += loss_.item()
    return epoch_loss / len(data_loader)
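Note that reusing train() for validation would also take optimizer steps on the validation data, so validation gets its own no-grad pass (a minimal sketch, assuming the same globals and batch layout as train):

@torch.no_grad()
def evaluate(data_loader):
    text_encoder.eval()
    text_decoder.eval()
    epoch_loss = 0
    for in_ids, in_msk, out_ids, out_msk in data_loader:
        encoder_out = text_encoder(in_ids, in_msk)
        loss_, _ = text_decoder(out_ids, out_msk, encoder_out, encoder_attention_mask=in_msk)
        epoch_loss += loss_.item()
    return epoch_loss / len(data_loader)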
history = {'train_loss': [], 'val_loss': []}
for epoch in range(20):
    train_loss = train(train_dataloader)
    val_loss = evaluate(val_dataloader)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    print(f'Epoch: {epoch} Train Loss: {train_loss} Val Loss: {val_loss}')
The loss is the language-modeling loss that BertLMHeadModel computes internally from the labels argument in forward. This is my generation function at inference time:
def generate_greedy(in_ten, in_mask):
    '''Two approaches might be to take the logits, convert them to a distribution
    and (1) take the argmax or (2) sample from the distribution.'''
    text_encoder.eval()
    text_decoder.eval()
    text_src = in_ten.to(device)
    text_mask = in_mask.to(device)
    text_feats = text_encoder(in_seq=text_src, in_mask=text_mask)
    # start every sequence with the decoder start token ([CLS] for BERT)
    input_ids = torch.full((text_feats.shape[0], 1), tokenizer.cls_token_id,
                           dtype=torch.long, device=device)
    output_sen = text_decoder.generate(input_ids, text_feats)
    input_sen = tokenizer.batch_decode(in_ten, skip_special_tokens=True)
    return input_sen, output_sen
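For comparison, here is option (1) from the docstring written out by hand: feed the encoder states explicitly at every step and take the argmax of the last position's logits (a minimal sketch, assuming text_decoder.decoder is the underlying BertLMHeadModel and text_feats/text_mask come from the encoder as above):

@torch.no_grad()
def greedy_decode(text_feats, text_mask, max_length=14):
    # start every sequence from the configured start token ([CLS])
    generated = torch.full((text_feats.shape[0], 1), tokenizer.cls_token_id,
                           dtype=torch.long, device=device)
    for _ in range(max_length - 1):
        out = text_decoder.decoder(input_ids=generated,
                                   encoder_hidden_states=text_feats,
                                   encoder_attention_mask=text_mask,
                                   return_dict=True)
        # greedy step: argmax over the vocabulary at the last position
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)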
But the output is:
Input sentence: a group of men are loading cotton onto a truck
Output sentence: a man in a suit is running past two other gentleman also also
******************************
Input sentence: a man sleeping in a green room on a couch
Output sentence: a man in a suit is running past two other gentleman also also
******************************
Input sentence: a boy wearing headphones sits on a womans shoulders
Output sentence: a man in a suit is running past two other gentleman also also
******************************
Input sentence: two men setting up a blue ice fishing hut on an iced
Output sentence: a man in a suit is running past two other gentleman also also
******************************
Input sentence: a balding man wearing a red life jacket is sitting in
Output sentence: a man in a suit is running past two other gentleman also also