Choosing the correct seq2seq model

Hi all,

I have a QA dataset for my company, and the questions and answers are similar enough across examples that I think a seq2seq model could be used to suggest answers. A QA pair could be something like, “Hey, I forgot to cancel service, please refund me. Thanks, Sachin”, “Hi Sachin, No worries we will refund you”. Note there is no context paragraph.

I am following this blog post on EncoderDecoderModel. I was able to follow most of it and implement it, but the generated text is pretty much gibberish, and after 6 hours of training on a GPU the loss has only gone from 7 → 6. I’m wondering if I’m doing something wrong.

What I was thinking instead was to concatenate each QA pair with a separator string in the middle and fine-tune a BART/GPT-2 model instead. Are there any examples of this? I know this GitHub issue asked about it a while back, but have there been any blogs/tutorials on the subject since?
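
Roughly what I have in mind (a sketch only, not tested; the separator string, model name, max length and the build_example helper are just placeholders):

# Sketch: build "question <sep> answer" strings and fine-tune GPT-2 as a plain
# causal language model on them. At inference time you would feed "question <sep>"
# and let the model generate the answer.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = GPT2LMHeadModel.from_pretrained("gpt2")

def build_example(question: str, answer: str, sep: str = " <sep> "):
    text = question + sep + answer + tokenizer.eos_token
    tokens = tokenizer(text, truncation=True, max_length=256, return_tensors="pt")
    tokens["labels"] = tokens["input_ids"].clone()  # standard LM objective
    return tokens

batch = build_example(
    "Hey, I forgot to cancel service, please refund me. Thanks, Sachin",
    "Hi Sachin, No worries we will refund you",
)
loss = model(**batch).loss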

This is the model I am currently using (a bert2bert EncoderDecoderModel), in case you see anything obviously wrong here. Otherwise I’m leaning towards the option outlined in the paragraph above:

from typing import List, Tuple

import torch
import pytorch_lightning as pl
from transformers import EncoderDecoderModel

# Tokenizer, BASE_MODEL, TRAIN_LOSS, VALID_LOSS and FREEZE are defined elsewhere in my code.


class Model(pl.LightningModule):
    def __init__(self, lr: float) -> None:
        super().__init__()
        self.lr = lr
        self.tokenizer = Tokenizer()
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(BASE_MODEL, BASE_MODEL)
        self.initialize_hyper_parameters()
        
        # Freeze everything except the cross-attention layers; the rest is unfrozen later in training_epoch_end.
        for name, param in self.model.named_parameters():
            if "crossattention" not in name:
                param.requires_grad = False
                
    def initialize_hyper_parameters(self):
        self.model.config.decoder_start_token_id = self.tokenizer.tokenizer.cls_token_id
        self.model.config.eos_token_id = self.tokenizer.tokenizer.sep_token_id
        self.model.config.pad_token_id = self.tokenizer.tokenizer.pad_token_id
        self.model.config.vocab_size = self.model.config.encoder.vocab_size

        self.model.config.max_length = 256
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.early_stopping = True
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4
        
        self.val_batch_count = 0
        
    def common_step(self, batch: Tuple[List[str], List[str]]) -> torch.FloatTensor:
        questions, answers = batch
        question_tokens = {k: v.to(self.device) for k, v in self.tokenizer(questions).items()}
        answer_tokens = {k: v.to(self.device) for k, v in self.tokenizer(answers).items()}
        labels = answer_tokens["input_ids"].clone()
        labels[answer_tokens["attention_mask"] == 0] = -100  # ignore padding positions in the loss
        
        outputs = self.model(
            input_ids=question_tokens["input_ids"], 
            attention_mask=question_tokens["attention_mask"],
            decoder_input_ids=answer_tokens["input_ids"], 
            decoder_attention_mask=answer_tokens["attention_mask"],
            labels=labels, 
            return_dict=True
        )
        
        return outputs["loss"]
    
    def training_step(self, batch: Tuple[List[str], List[str]], *args) -> torch.FloatTensor:
        loss = self.common_step(batch)
        self.log(TRAIN_LOSS, loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch: Tuple[List[str], List[str]], *args) -> None:
        loss = self.common_step(batch)
        if self.val_batch_count == 0:
            self.generate_examples(batch)
        self.log(VALID_LOSS, loss, on_step=True, on_epoch=True)
        self.val_batch_count += 1
        
    def generate_examples(self, batch):
        questions, answers = batch
        question_tokens = {k: v.to(self.device) for k, v in self.tokenizer(questions).items()}

        generated = self.model.generate(
            input_ids=question_tokens["input_ids"], 
            attention_mask=question_tokens["attention_mask"], 
#             decoder_start_token_id=self.model.config.decoder.pad_token_id
        )

        print(self.tokenizer.decode(question_tokens["input_ids"][0]))  # the question
        print(self.tokenizer.decode(generated[0]))  # the generated answer
        
    def validation_epoch_end(self, *args):
        self.val_batch_count = 0  # reset once per epoch so only the first validation batch is printed
        
    def training_epoch_end(self, *args) -> None:
        if self.current_epoch == FREEZE:
            print("Unfreezing")
            for name, param in self.model.named_parameters():
                if "crossattention" not in name:
                    param.requires_grad = True
                    
                    
    def configure_optimizers(self) -> torch.optim.Adam:
        cross_attention_params = []
        embedding_params = []
        other_params = []
        
        for name, param in self.model.named_parameters():
            if "crossattention" in name:
                cross_attention_params.append(param)
            elif "embedding" in name:
                embedding_params.append(param)
            else:
                other_params.append(param)
        
        return torch.optim.Adam(
            [
                {"params": cross_attention_params, "lr": self.lr},
                {"params": other_params, "lr": self.lr / 20},
                {"params": embedding_params, "lr": self.lr / 100},
            ]
        )

Dear Sachin,

The models that we train are only as good as the data that we train them on. So start by thinking about your data.

Your dataset assumes that a sequence-to-sequence model can predict the answer to a question. But sequence-to-sequence models were originally designed for translation: they were designed to predict an equivalent sequence.

So I would expect a sequence-to-sequence model to produce the correct answer only if the question itself contains the information needed to construct that answer.

Perhaps instead you could group your answers into classes? If so, you could train a classifier to predict the correct answer class for a given question.
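
For example, a minimal sketch (the answer classes and model name below are only placeholders):

# Sketch: treat each canned answer as a class and fine-tune an ordinary sequence
# classifier on the questions. The classes below are placeholders.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

answer_classes = ["refund", "cancel_service", "reset_password"]  # placeholders
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=len(answer_classes)
)

question = "Hey, I forgot to cancel service, please refund me. Thanks, Sachin"
inputs = tokenizer(question, return_tensors="pt", truncation=True)
labels = torch.tensor([answer_classes.index("refund")])

loss = model(**inputs, labels=labels).loss          # fine-tune on this loss
predicted = model(**inputs).logits.argmax(dim=-1)   # at inference, map back to a canned answer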

Or you might look at Facebook’s BART model. It’s a sequence-to-sequence model, but at inference time it can be used to predict a classification.

And what’s really cool is that it does zero-shot classification: the model did not see the class labels during training. So in practice you might find many uses for a single trained model.
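
For instance, with the zero-shot classification pipeline in the transformers library (the candidate labels here are just illustrative):

# Zero-shot classification with BART fine-tuned on MNLI; the model never saw
# these candidate labels during training.
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "Hey, I forgot to cancel service, please refund me. Thanks, Sachin",
    candidate_labels=["refund request", "cancellation", "technical support"],
)
print(result["labels"][0])  # highest-scoring label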

Best wishes,
- Eryk