Hi all,
I have a QA dataset for my company, and the questions and answers are similar enough across pairs that I think a seq2seq model could be used to suggest answers. A QA pair could be something like “Hey, I forgot to cancel service, please refund me. Thanks, Sachin”, “Hi Sachin, No worries we will refund you”. Note there is no context paragraph.
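For context, a batch is just raw question and answer strings, roughly like this (a simplified sketch rather than my exact dataset code; the batch size is a placeholder and tokenization happens later inside the model):

from typing import List, Tuple

from torch.utils.data import Dataset, DataLoader


class QADataset(Dataset):
    """Each item is simply a (question, answer) string tuple."""

    def __init__(self, pairs: List[Tuple[str, str]]) -> None:
        self.pairs = pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.pairs[idx]


def collate(batch: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    # Collate into the Tuple[List[str], List[str]] shape the model expects
    questions, answers = zip(*batch)
    return list(questions), list(answers)


pairs = [
    ("Hey, I forgot to cancel service, please refund me. Thanks, Sachin",
     "Hi Sachin, No worries we will refund you"),
]
train_loader = DataLoader(QADataset(pairs), batch_size=8, collate_fn=collate)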
I am following this blog post on EncoderDecoderModels. I was able to follow most of it and implemented it, but the generated results look pretty much like gibberish, and after 6 hours of training on a GPU the loss has only gone from 7 → 6. I’m wondering if I’m doing something wrong.
What I was thinking of doing instead was concatenating each QA pair with a separator string in the middle and using a BART/GPT-2 model instead. Are there any examples of this? I know this GitHub issue asked about it a while back, but have there been any blogs/tutorials on the subject since?
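Roughly what I have in mind for that route, just as a sketch (the <sep> token and the masking of the question part are my own guesses, not taken from any tutorial):

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.add_special_tokens({"sep_token": "<sep>", "pad_token": "<pad>"})

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))  # account for the added tokens

question = "Hey, I forgot to cancel service, please refund me. Thanks, Sachin"
answer = "Hi Sachin, No worries we will refund you"

# Concatenate question + separator + answer into a single sequence
text = question + tokenizer.sep_token + answer + tokenizer.eos_token
enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

# Use the input ids as labels, but set the question part (up to and including
# the separator) to -100 so the loss only covers the answer tokens.
labels = enc["input_ids"].clone()
sep_position = (labels == tokenizer.sep_token_id).nonzero()[0, 1].item()
labels[0, : sep_position + 1] = -100

outputs = model(**enc, labels=labels)  # GPT-2 shifts the labels internally
loss = outputs.loss

# At inference time the idea would be to feed only question + tokenizer.sep_token
# and call model.generate(...) to produce a suggested answer.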
This is the model I am currently using, a bert2bert EncoderDecoderModel, in case you see anything obviously wrong with it. Otherwise I’m leaning more towards the BART/GPT-2 option outlined above:
from typing import List, Tuple

import torch
import pytorch_lightning as pl
from transformers import EncoderDecoderModel

# Tokenizer, BASE_MODEL, TRAIN_LOSS, VALID_LOSS and FREEZE are defined elsewhere
# in my code (Tokenizer wraps the pretrained tokenizer for BASE_MODEL).


class Model(pl.LightningModule):
    def __init__(self, lr: float) -> None:
        super().__init__()
        self.lr = lr
        self.tokenizer = Tokenizer()
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(BASE_MODEL, BASE_MODEL)
        self.initialize_hyper_parameters()
        # Freeze everything except the (randomly initialised) cross-attention
        # layers for the first FREEZE epochs
        for name, param in self.model.named_parameters():
            if "crossattention" not in name:
                param.requires_grad = False

    def initialize_hyper_parameters(self):
        self.model.config.decoder_start_token_id = self.tokenizer.tokenizer.cls_token_id
        self.model.config.eos_token_id = self.tokenizer.tokenizer.sep_token_id
        self.model.config.pad_token_id = self.tokenizer.tokenizer.pad_token_id
        self.model.config.vocab_size = self.model.config.encoder.vocab_size
        # Generation settings
        self.model.config.max_length = 256
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.early_stopping = True
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4
        self.val_batch_count = 0

    def common_step(self, batch: Tuple[List[str], List[str]]) -> torch.FloatTensor:
        questions, answers = batch
        question_tokens = {k: v.to(self.device) for k, v in self.tokenizer(questions).items()}
        answer_tokens = {k: v.to(self.device) for k, v in self.tokenizer(answers).items()}
        # Mask padding positions so they are ignored by the loss
        labels = answer_tokens["input_ids"].clone()
        labels[answer_tokens["attention_mask"] == 0] = -100
        outputs = self.model(
            input_ids=question_tokens["input_ids"],
            attention_mask=question_tokens["attention_mask"],
            decoder_input_ids=answer_tokens["input_ids"],
            decoder_attention_mask=answer_tokens["attention_mask"],
            labels=labels,
            return_dict=True,
        )
        return outputs["loss"]

    def training_step(self, batch: Tuple[List[str], List[str]], *args) -> torch.FloatTensor:
        loss = self.common_step(batch)
        self.log(TRAIN_LOSS, loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[List[str], List[str]], *args) -> None:
        loss = self.common_step(batch)
        # Only print generated examples for the first validation batch of the epoch
        if self.val_batch_count == 0:
            self.generate_examples(batch)
        self.log(VALID_LOSS, loss, on_step=True, on_epoch=True)
        self.val_batch_count += 1

    def generate_examples(self, batch):
        questions, answers = batch
        question_tokens = {k: v.to(self.device) for k, v in self.tokenizer(questions).items()}
        generated = self.model.generate(
            input_ids=question_tokens["input_ids"],
            attention_mask=question_tokens["attention_mask"],
            # decoder_start_token_id=self.model.config.decoder.pad_token_id
        )
        print(self.tokenizer.decode(question_tokens["input_ids"][0]))
        print(self.tokenizer.decode(generated[0]))

    def validation_epoch_end(self, *args) -> None:
        self.val_batch_count = 0  # reset for the next epoch

    def training_epoch_end(self, *args) -> None:
        # Unfreeze the rest of the model once the warm-up epochs are done
        if self.current_epoch == FREEZE:
            print("Unfreezing")
            for name, param in self.model.named_parameters():
                if "crossattention" not in name:
                    param.requires_grad = True

    def configure_optimizers(self) -> torch.optim.Adam:
        # Higher learning rate for the new cross-attention layers, lower rates
        # for the pretrained weights and embeddings
        cross_attention_params = []
        embedding_params = []
        other_params = []
        for name, param in self.model.named_parameters():
            if "crossattention" in name:
                cross_attention_params.append(param)
            elif "embedding" in name:
                embedding_params.append(param)
            else:
                other_params.append(param)
        return torch.optim.Adam(
            [
                {"params": cross_attention_params, "lr": self.lr},
                {"params": other_params, "lr": self.lr / 20},
                {"params": embedding_params, "lr": self.lr / 100},
            ]
        )
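For completeness, the training loop around it is just a plain Lightning Trainer, roughly like this (assuming PyTorch Lightning 1.x; the dataloaders and hyperparameter values here are placeholders, not my real config):

import pytorch_lightning as pl

train_loader = ...  # DataLoader of (question, answer) pairs, as in the sketch further up
val_loader = ...    # held-out pairs

model = Model(lr=1e-4)  # placeholder learning rate
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(model, train_loader, val_loader)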