Dear all, I am fine-tuning T5 for a Q&A task using the MedQuAD dataset (GitHub - abachaa/MedQuAD: Medical Question Answering Dataset of 47,457 QA pairs created from 12 NIH websites). In the dataset, there are many long answers with thousands of words. I have used pytorch_lightning to train the T5-large model. I have two questions.
For example, I set max_length, max_input_length, and max_output_length all to 128.
- How should I deal with those long answers? I just left them as they are and let the T5Tokenizer handle them automatically. I assume the tokenizer simply truncates an answer at the 128th (or 127th) word. Would it be possible to manually split an answer into parts of 128 words each, and then use each of these sub-answers as a separate answer to the same question?
- My other question is that I get incomplete (truncated) answers when using the fine-tuned model for inference, even though the predicted answer is shorter than 128 words. I found a message posted 2 years ago saying that one should add `</s>` at the end of texts when fine-tuning T5. I followed that advice but then got a warning message that duplicated `</s>` tokens were found. I assume the truncation happens because the tokenizer truncates an answer text, so `</s>` is missing from the truncated answer and the end token is never produced in the predicted answer. However, I am not sure. Can anybody point out how to address this issue?
Any suggestions are highly appreciated.
Below are some code snippets.
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
import numpy as np
import time
from pathlib import Path
from transformers import (
Adafactor,
T5ForConditionalGeneration,
T5Tokenizer,
get_linear_schedule_with_warmup
)
from torch.utils.data import RandomSampler
from question_answering.utils import *
class T5FineTuner(pl.LightningModule):
    """LightningModule that fine-tunes a T5 seq2seq model on question/answer pairs.

    The module owns the model, the tokenizer, the three data splits and the
    optimizer / learning-rate scheduler. ``hyparams`` is any attribute
    container (e.g. ``argparse.Namespace``) providing the fields read below:
    ``model_name_or_path``, ``tokenizer_name_or_path``, ``freeze_embeds``,
    ``freeze_encoder``, ``output_dir``, ``n_train``, ``n_val``, ``n_test``,
    ``weight_decay``, ``learning_rate``, ``warmup_steps``,
    ``train_batch_size`` and ``eval_batch_size``. Two optional fields are
    honoured when present: ``data_dir`` (dataset folder) and
    ``max_output_length`` (generation length limit).
    """

    def __init__(self, hyparams):
        """Load model and tokenizer, apply optional freezing, and load the data."""
        super().__init__()
        self.hyparams = hyparams
        self.model = T5ForConditionalGeneration.from_pretrained(hyparams.model_name_or_path)
        self.tokenizer = T5Tokenizer.from_pretrained(hyparams.tokenizer_name_or_path)

        if self.hyparams.freeze_embeds:
            self.freeze_embeds()
        if self.hyparams.freeze_encoder:
            self.freeze_params(self.model.get_encoder())

        self.step_count = 0
        self.output_dir = Path(self.hyparams.output_dir)

        # A negative count means "no limit" and is stored as None.
        n_observations_per_split = {
            'train': self.hyparams.n_train,
            'validation': self.hyparams.n_val,
            'test': self.hyparams.n_test
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        # Cumulative history of per-batch validation scores (never reset).
        self.em_score_list = []
        self.subset_score_list = []

        # The dataset folder can be overridden with an optional `data_dir`
        # hyper-parameter; the original hard-coded path stays the default.
        data_folder = getattr(self.hyparams, 'data_dir', r'C:\Datasets\MedQuAD-master')
        # load_medqa_data is defined outside this class; it is expected to
        # return the (train, validation, test) splits in that order.
        self.train_data, self.val_data, self.test_data = load_medqa_data(data_folder)

    def freeze_params(self, model):
        """Disable gradient updates for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def freeze_embeds(self):
        """Freeze the embedding layers of the wrapped model.

        First tries the layout where the network is nested under
        ``self.model.model`` (with positional embeddings); if any of those
        attributes is missing, falls back to the layout where ``shared``,
        ``encoder`` and ``decoder`` sit directly on ``self.model``.
        """
        try:
            self.freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                self.freeze_params(d.embed_positions)
                self.freeze_params(d.embed_tokens)
        except AttributeError:
            self.freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                self.freeze_params(d.embed_tokens)

    def lmap(self, f, x):
        """Return ``list(map(f, x))``."""
        return list(map(f, x))

    def is_logger(self):
        """Return True when this process has rank <= 0.

        Prefers ``trainer.global_rank`` and falls back to the older
        ``trainer.proc_rank`` attribute when ``global_rank`` is unavailable.
        """
        trainer = self.trainer
        rank = getattr(trainer, 'global_rank', None)
        if rank is None:
            rank = trainer.proc_rank
        return rank <= 0

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels
        )

    def _step(self, batch):
        """Run a teacher-forced forward pass and return the loss.

        ``batch`` must provide ``source_ids``, ``source_mask``, ``target_ids``
        and ``target_mask``.
        """
        # Work on a copy: masking the pad positions in place would overwrite
        # batch['target_ids'] with -100 and break any later decoding of it.
        labels = batch['target_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch['source_ids'],
            attention_mask=batch['source_mask'],
            labels=labels,
            decoder_attention_mask=batch['target_mask']
        )
        # The loss is the first element of the model output when labels are given.
        return outputs[0]

    def ids_to_clean_text(self, generated_ids):
        """Decode a batch of token ids into stripped strings without special tokens."""
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return self.lmap(str.strip, gen_text)

    def _generative_step(self, batch):
        """Generate answers for ``batch`` and compute loss and match scores.

        Returns a dict with ``val_loss``, ``gen_time`` (seconds per example),
        ``gen_len``, ``preds``, ``target``, ``em_score`` and
        ``subset_match_score``.
        """
        t0 = time.time()
        # Use the configured output length when available (128 otherwise).
        max_length = getattr(self.hyparams, 'max_output_length', 128)
        # The gold target mask is deliberately NOT passed to generate(): the
        # decoder sequence is produced step by step, so a mask describing the
        # reference answer does not apply to it.
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            use_cache=True,
            max_length=max_length,
            num_beams=2,
            early_stopping=True
        )
        preds = self.ids_to_clean_text(generated_ids)
        targets = self.ids_to_clean_text(batch["target_ids"])
        gen_time = (time.time() - t0) / batch["source_ids"].shape[0]

        loss = self._step(batch)
        base_metrics = {'val_loss': loss}
        summ_len = np.mean(self.lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=targets)

        # calculate_scores is defined outside this class; presumably it
        # returns (exact-match, subset-match) scalars -- verify against utils.
        em_score, subset_match_score = calculate_scores(preds, targets)
        self.em_score_list.append(em_score)
        self.subset_score_list.append(subset_match_score)
        em_score = torch.tensor(em_score, dtype=torch.float32)
        subset_match_score = torch.tensor(subset_match_score, dtype=torch.float32)
        base_metrics.update(em_score=em_score, subset_match_score=subset_match_score)
        return base_metrics

    def training_step(self, batch, batch_idx):
        """Return the training loss together with a ``log`` dict."""
        loss = self._step(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def training_epoch_end(self, outputs):
        """Compute the mean training loss of the epoch; nothing is returned."""
        avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'avg_train_loss': avg_train_loss}

    def validation_step(self, batch, batch_idx):
        """Run one generative evaluation step."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs):
        """Average loss and match scores over the batches of this validation run.

        The scores are taken from ``outputs`` (the current run only), so the
        monitored ``em_score`` is not diluted by earlier validation runs.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        average_em_score = torch.stack(
            [torch.as_tensor(x['em_score'], dtype=torch.float32) for x in outputs]
        ).mean()
        average_subset_match_score = torch.stack(
            [torch.as_tensor(x['subset_match_score'], dtype=torch.float32) for x in outputs]
        ).mean()

        tensorboard_logs = {'val_loss': avg_loss}
        tensorboard_logs.update(em_score=average_em_score, subset_match_score=average_subset_match_score)
        self.target_gen = []
        self.prediction_gen = []
        return {
            'avg_val_loss': avg_loss,
            'em_score': average_em_score,
            'subset_match_score': average_subset_match_score,
            # Misspelled key kept so existing readers of it keep working.
            'subset_match_socre': average_subset_match_score,
            'log': tensorboard_logs,
            'progress_bar': tensorboard_logs
        }

    def configure_optimizers(self):
        """Build an Adafactor optimizer with weight decay disabled for bias/LayerNorm.

        The optimizer is also stored on ``self.opt`` because
        ``train_dataloader`` needs it to create the scheduler.
        """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hyparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = Adafactor(optimizer_grouped_parameters, lr=self.hyparams.learning_rate, scale_parameter=False,
                              relative_step=False)
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure=None,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, clear gradients, then advance the LR scheduler."""
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()
        # self.lr_scheduler is created in train_dataloader().
        self.lr_scheduler.step()

    def get_tqdm_dict(self):
        """Return the formatted average loss and current learning rate."""
        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
        return tqdm_dict

    def train_dataloader(self):
        """Build the shuffled training loader and create the LR scheduler.

        Relies on ``self.opt`` having been set by ``configure_optimizers``.
        """
        n_samples = self.n_obs['train']
        # get_dataset is defined outside this class.
        train_dataset = get_dataset(tokenizer=self.tokenizer, data=self.train_data, num_samples=n_samples,
                                    args=self.hyparams)
        sampler = RandomSampler(train_dataset)
        dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=self.hyparams.train_batch_size,
                                drop_last=True, num_workers=4)
        # Fixed scheduler horizon; it is not derived from the dataset size.
        t_total = 100000
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hyparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        """Build the validation loader (batches are drawn by a RandomSampler)."""
        n_samples = self.n_obs['validation']
        validation_dataset = get_dataset(tokenizer=self.tokenizer, data=self.val_data, num_samples=n_samples,
                                         args=self.hyparams)
        sampler = RandomSampler(validation_dataset)
        return DataLoader(validation_dataset, shuffle=False, batch_size=self.hyparams.eval_batch_size,
                          sampler=sampler, num_workers=4)

    def test_dataloader(self):
        """Build the test loader."""
        n_samples = self.n_obs['test']
        test_dataset = get_dataset(tokenizer=self.tokenizer, data=self.test_data, num_samples=n_samples,
                                   args=self.hyparams)
        return DataLoader(test_dataset, batch_size=self.hyparams.eval_batch_size, num_workers=4)

    def on_save_checkpoint(self, checkpoint):
        """Also export model and tokenizer to ``<output_dir>/best_tfmr``."""
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
import os
import argparse
import pytorch_lightning as pl
from question_answering.t5_closed_book import T5FineTuner
if __name__ == '__main__':
    # Baseline hyper-parameters; run-specific overrides are applied below.
    args_dict = dict(
        output_dir="",  # path to save the checkpoints
        model_name_or_path='t5-large',
        tokenizer_name_or_path='t5-large',
        max_input_length=128,
        max_output_length=128,
        freeze_encoder=False,
        freeze_embeds=False,
        learning_rate=1e-5,
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        train_batch_size=4,
        eval_batch_size=4,
        num_train_epochs=2,
        gradient_accumulation_steps=10,
        n_gpu=1,
        resume_from_checkpoint=None,
        val_check_interval=0.5,
        n_val=4000,
        n_train=-1,
        n_test=-1,
        early_stop_callback=False,
        fp_16=False,
        opt_level='O1',
        max_grad_norm=1.0,
        seed=101,
    )
    args_dict.update({'output_dir': 't5_large_MedQuAD_256', 'num_train_epochs': 100,
                      'train_batch_size': 16, 'eval_batch_size': 16, 'learning_rate': 1e-3})
    args = argparse.Namespace(**args_dict)

    # The seed was configured but never applied; seed all RNGs so that runs
    # are reproducible.
    pl.seed_everything(args.seed)

    # Keep only the checkpoint with the highest monitored "em_score".
    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=args.output_dir, monitor="em_score", mode="max", save_top_k=1)

    ## If resuming from checkpoint, add an arg resume_from_checkpoint
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        # early_stop_callback=False,
        precision=16 if args.fp_16 else 32,
        # amp_level=args.opt_level,
        # resume_from_checkpoint=args.resume_from_checkpoint,
        gradient_clip_val=args.max_grad_norm,
        # Callback instances are registered through `callbacks`; the
        # `checkpoint_callback` Trainer argument is meant to be a boolean.
        callbacks=[checkpoint_callback],
        val_check_interval=args.val_check_interval,
        # accelerator='dp'
        # logger=wandb_logger,
    )

    model = T5FineTuner(args)
    trainer = pl.Trainer(**train_params)
    trainer.fit(model)