T5 for Q&A (truncated sentences and long answers)

Dear all, I am fine-tuning T5 for a Q&A task on the MedQuAD dataset (GitHub - abachaa/MedQuAD: Medical Question Answering Dataset of 47,457 QA pairs created from 12 NIH websites). The dataset contains many long answers running to thousands of words. I am using PyTorch Lightning to train a T5-large model, and I have two questions.

For context, I set max_length, max_input_length, and max_output_length all to 128.

  1. How should I deal with those long answers? I just left them as is, assuming the T5Tokenizer handles them automatically; I believe it simply truncates an answer after the 128th token. Would it make sense to manually split a long answer into chunks of at most 128 tokens and treat each chunk as a separate answer to the same question? (A chunking sketch follows right after this list.)
  2. I also get incomplete (truncated) answers at inference time with the fine-tuned model, even when the predicted answer is shorter than 128 words. I found a post from two years ago saying that one should append </s> to the end of the target texts when fine-tuning T5. I followed that, but then got a warning that duplicated </s> tokens were found. My guess is that the tokenizer truncates long answer texts, so </s> is missing from the truncated targets and the end token never shows up in the predicted answers; however, I am not sure. Can anybody point out how to address this issue? (See the EOS check after this list.)
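
For question 1, here is a minimal sketch of the chunking idea. It is separate from the training code below; the "question: ..." prompt format and the names question, long_answer, and chunk_size are placeholders you would adapt to your get_dataset helper.

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-large')

def chunk_answer(question, long_answer, chunk_size=127):
    # Tokenize the full answer without truncation so no content is lost;
    # chunk_size=127 leaves one slot for the </s> token the tokenizer appends later.
    answer_ids = tokenizer(long_answer, add_special_tokens=False)['input_ids']
    examples = []
    for start in range(0, len(answer_ids), chunk_size):
        chunk_text = tokenizer.decode(answer_ids[start:start + chunk_size])
        # Each chunk becomes a separate target for the same question
        examples.append({'source': f'question: {question}', 'target': chunk_text})
    return examples

Whether each chunk still reads as a sensible standalone answer is a separate question; overlapping chunks or splitting at sentence boundaries may work better.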

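For question 2, here is a minimal, standalone check of how the tokenizer handles the end token. If the targets are encoded with truncation=True and the default add_special_tokens=True, the tokenizer appends </s> itself and reserves room for it, so it survives truncation; manually writing "</s>" into the raw text is what triggers the duplicate-token warning. Whether a missing end token is really what causes the truncated generations is an assumption worth verifying, e.g.:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-large')

long_answer = 'word ' * 5000  # stand-in for a very long MedQuAD answer
enc = tokenizer(long_answer, max_length=128, truncation=True, padding='max_length')
# Drop padding and look at the last real token: it should be </s> (eos_token_id)
last_real_token = [t for t in enc['input_ids'] if t != tokenizer.pad_token_id][-1]
print(last_real_token == tokenizer.eos_token_id)  # expected: True

If this prints True for your actual targets, the end token is present during training, and the cut-off answers at inference time are more likely caused by generate() hitting max_length (128 tokens, which is usually fewer than 128 words) before the model emits </s>.
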
Any suggestions are highly appreciated.

Below are the relevant code snippets: the LightningModule and the training script.

import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
from pathlib import Path

from transformers import (
    Adafactor,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
from torch.utils.data import RandomSampler
from question_answering.utils import *


class T5FineTuner(pl.LightningModule):
    def __init__(self, hyparams):
        super(T5FineTuner, self).__init__()
        self.hyparams = hyparams
        self.model = T5ForConditionalGeneration.from_pretrained(hyparams.model_name_or_path)
        self.tokenizer = T5Tokenizer.from_pretrained(hyparams.tokenizer_name_or_path)

        if self.hyparams.freeze_embeds:
            self.freeze_embeds()
        if self.hyparams.freeze_encoder:
            self.freeze_params(self.model.get_encoder())
            # assert_all_frozen()

        self.step_count = 0
        self.output_dir = Path(self.hyparams.output_dir)

        n_observations_per_split = {
            'train': self.hyparams.n_train,
            'validation': self.hyparams.n_val,
            'test': self.hyparams.n_test
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
        self.em_score_list = []
        self.subset_score_list = []

        data_folder = r'C:\Datasets\MedQuAD-master'
        self.train_data, self.val_data, self.test_data = load_medqa_data(data_folder)

    def freeze_params(self, model):
        for param in model.parameters():
            param.requires_grad = False

    def freeze_embeds(self):
        # Note: T5ForConditionalGeneration has no `.model` submodule and no
        # positional embeddings, so the except branch below is the one that runs.
        try:
            self.freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                self.freeze_params(d.embed_positions)
                self.freeze_params(d.embed_tokens)
        except AttributeError:
            self.freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                self.freeze_params(d.embed_tokens)

    def lmap(self, f, x):
        return list(map(f, x))

    def is_logger(self):
        # proc_rank was renamed to global_rank in PyTorch Lightning 0.9+
        return self.trainer.global_rank <= 0

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels
        )

    def _step(self, batch):
        labels = batch['target_ids']
        # Replace pad token ids with -100 so they are ignored by the loss
        labels[labels == self.tokenizer.pad_token_id] = -100

        outputs = self(
            input_ids = batch['source_ids'],
            attention_mask=batch['source_mask'],
            labels=labels,
            decoder_attention_mask=batch['target_mask']
        )

        loss = outputs[0]

        return loss

    def ids_to_clean_text(self, generated_ids):
        gen_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return self.lmap(str.strip, gen_text)

    def _generative_step(self, batch):
        t0 = time.time()

        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            use_cache=True,
            decoder_attention_mask=batch['target_mask'],
            max_length=128,
            num_beams=2,
            early_stopping=True
        )
        preds = self.ids_to_clean_text(generated_ids)
        targets = self.ids_to_clean_text(batch["target_ids"])

        gen_time = (time.time() - t0) / batch["source_ids"].shape[0]

        loss = self._step(batch)
        base_metrics = {'val_loss': loss}
        summ_len = np.mean(self.lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=targets)
        em_score, subset_match_score = calculate_scores(preds, targets)

        self.em_score_list.append(em_score)
        self.subset_score_list.append(subset_match_score)

        em_score = torch.tensor(em_score, dtype=torch.float32)
        subset_match_score = torch.tensor(subset_match_score, dtype=torch.float32)

        base_metrics.update(em_score=em_score, subset_match_score=subset_match_score)

        #         rouge_results = self.rouge_metric.compute()
        #         rouge_dict = self.parse_score(rouge_results)

        return base_metrics

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'avg_train_loss': avg_train_loss}
        # return {'avg_train_loss': avg_train_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}

        if len(self.em_score_list) <= 2:
            average_em_score = sum(self.em_score_list) / len(self.em_score_list)
            average_subset_match_score = sum(self.subset_score_list) / len(self.subset_score_list)

        else:
            # Average only the two most recent scores from this validation run
            latest_em_score = self.em_score_list[-2:]
            latest_subset_score = self.subset_score_list[-2:]
            average_em_score = sum(latest_em_score) / len(latest_em_score)
            average_subset_match_score = sum(latest_subset_score) / len(latest_subset_score)

        average_em_score = torch.tensor(average_em_score, dtype=torch.float32)
        average_subset_match_score = torch.tensor(average_subset_match_score, dtype=torch.float32)
        tensorboard_logs.update(em_score=average_em_score, subset_match_score=average_subset_match_score)

        self.target_gen = []
        self.prediction_gen = []
        return {
            'avg_val_loss': avg_loss,
            'em_score': average_em_score,
            'subset_match_score': average_subset_match_score,
            'log': tensorboard_logs,
            'progress_bar': tensorboard_logs
        }

    def configure_optimizers(self):
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hyparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        optimizer = Adafactor(optimizer_grouped_parameters, lr=self.hyparams.learning_rate, scale_parameter=False,
                              relative_step=False)
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure=None,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()
        self.lr_scheduler.step()

    def get_tqdm_dict(self):
        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

        return tqdm_dict

    def train_dataloader(self):
        n_samples = self.n_obs['train']
        train_dataset = get_dataset(tokenizer=self.tokenizer, data=self.train_data, num_samples=n_samples,
                                    args=self.hyparams)
        sampler = RandomSampler(train_dataset)
        dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=self.hyparams.train_batch_size,
                                drop_last=True, num_workers=4)
        # t_total = (
        #         (len(dataloader.dataset) // (self.hyparams.train_batch_size * max(1, self.hyparams.n_gpu)))
        #         // self.hyparams.gradient_accumulation_steps
        #         * float(self.hyparams.num_train_epochs)
        # )
        t_total = 100000
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hyparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        n_samples = self.n_obs['validation']

        validation_dataset = get_dataset(tokenizer=self.tokenizer, data=self.val_data, num_samples=n_samples,
                                         args=self.hyparams)
        sampler = RandomSampler(validation_dataset)
        return DataLoader(validation_dataset, shuffle=False, batch_size=self.hyparams.eval_batch_size, sampler=sampler, num_workers=4)

    def test_dataloader(self):
        n_samples = self.n_obs['test']
        test_dataset = get_dataset(tokenizer=self.tokenizer, data=self.test_data, num_samples=n_samples, args=self.hyparams)

        return DataLoader(test_dataset, batch_size=self.hyparams.eval_batch_size, num_workers=4)

    def on_save_checkpoint(self, checkpoint):
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

And here is the training script:

import os
import argparse
import pytorch_lightning as pl

from question_answering.t5_closed_book import T5FineTuner

if __name__ == '__main__':    
    args_dict = dict(
        output_dir="",  # path to save the checkpoints
        model_name_or_path='t5-large',
        tokenizer_name_or_path='t5-large',
        max_input_length=128,
        max_output_length=128,
        freeze_encoder=False,
        freeze_embeds=False,
        learning_rate=1e-5,
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        train_batch_size=4,
        eval_batch_size=4,
        num_train_epochs=2,
        gradient_accumulation_steps=10,
        n_gpu=1,
        resume_from_checkpoint=None,
        val_check_interval=0.5,
        n_val=4000,
        n_train=-1,
        n_test=-1,
        early_stop_callback=False,
        fp_16=False,  
        opt_level='O1',        
        max_grad_norm=1.0,  
        seed=101,
    )

    args_dict.update({'output_dir': 't5_large_MedQuAD_256', 'num_train_epochs': 100,
                      'train_batch_size': 16, 'eval_batch_size': 16, 'learning_rate': 1e-3})                      
    args = argparse.Namespace(**args_dict)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=args.output_dir, monitor="em_score", mode="max", save_top_k=1)

    ## If resuming from checkpoint, add an arg resume_from_checkpoint
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        # early_stop_callback=False,
        precision=16 if args.fp_16 else 32,
        # amp_level=args.opt_level,
        # resume_from_checkpoint=args.resume_from_checkpoint,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        val_check_interval=args.val_check_interval,
        # accelerator='dp'
        # logger=wandb_logger,
        # callbacks=[LoggingCallback()],
    )

    model = T5FineTuner(args)
    trainer = pl.Trainer(**train_params)

    trainer.fit(model)
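
For completeness, here is a minimal inference sketch. It assumes the best_tfmr directory written by on_save_checkpoint above and a "question: ..." prompt format; adjust both to whatever your get_dataset helper actually uses.

from transformers import T5ForConditionalGeneration, T5Tokenizer

model_dir = 't5_large_MedQuAD_256/best_tfmr'  # directory written by on_save_checkpoint
tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)

# Placeholder question in the assumed prompt format
inputs = tokenizer('question: What are the symptoms of glaucoma?',
                   return_tensors='pt', max_length=128, truncation=True)
output_ids = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_length=256,       # allow more room than the training max_output_length of 128
    num_beams=2,
    early_stopping=True,  # stop each beam once </s> is generated
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))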