Hello, I am fine-tuning Pegasus on a summarization task and want to integrate a domain adaptation script into the training, which requires me to separate out the encoder and decoder objects of the PegasusForConditionalGeneration class. This is the code I have so far:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, PegasusForConditionalGeneration

model = SummarizerNet(model_name, args)
encoder = model.encoder
decoder = model.decoder
lm_head = model.lm_head
vocab_size = model.vocab_size
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(args.device)
# under DataParallel the underlying SummarizerNet lives at model.module
net = model.module if isinstance(model, nn.DataParallel) else model

src_loader = DataLoader(src_dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=1, pin_memory=True, drop_last=True)
n_batches = len(src_loader)
optim = torch.optim.AdamW(model.parameters(), weight_decay=0.01)
# the default ignore_index=-100 skips the masked label positions
summarizer_criterion = nn.CrossEntropyLoss()

for epoch in range(args.num_epochs):
    model.train()
    total_summarization_loss = 0
    for src_batch in tqdm(src_loader, leave=False, total=n_batches):
        # unpack the batch dictionary (relies on the dataset's key order)
        src_input_ids, src_attention_mask, src_decoder_input_ids, \
            src_decoder_attention_mask, src_labels = (v.to(args.device) for v in src_batch.values())
        # rebuild the decoder inputs by shifting the labels one token to the right
        src_decoder_input_ids = shift_tokens_right(src_labels, net.config.pad_token_id,
                                                   net.config.decoder_start_token_id)
        # pass the source-domain documents through the encoder
        encoder_features_src = encoder(input_ids=src_input_ids,
                                       attention_mask=src_attention_mask).last_hidden_state
        # pass the encoder hidden states and the shifted reference summaries to the decoder
        decoder_outputs = decoder(input_ids=src_decoder_input_ids,
                                  attention_mask=src_decoder_attention_mask,
                                  encoder_hidden_states=encoder_features_src,
                                  encoder_attention_mask=src_attention_mask).last_hidden_state
        # add the final logits bias (as in src/transformers/models/pegasus/modeling_pegasus.py)
        lm_head_output = lm_head(decoder_outputs) + net.model.final_logits_bias
        loss = summarizer_criterion(lm_head_output.reshape(-1, vocab_size),
                                    src_labels.reshape(-1))
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_summarization_loss += loss.item()
    mean_summarization_loss = total_summarization_loss / n_batches
    tqdm.write(f'EPOCH {epoch+1:03d}: summarization_loss={mean_summarization_loss:.4f}')

save_dir = os.path.dirname(save_path)
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
model_checkpoint = {'model_state_dict': net.model.state_dict(),
                    'optimizer': optim.state_dict()}
torch.save(model_checkpoint, save_path)
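Since the checkpoint stores the state dict of the inner PegasusForConditionalGeneration, I reload it later along these lines (a minimal sketch, assuming the same model_name is available at load time):

# reload the fine-tuned weights into a fresh PegasusForConditionalGeneration (sketch)
reloaded = PegasusForConditionalGeneration.from_pretrained(model_name)
checkpoint = torch.load(save_path, map_location='cpu')
reloaded.load_state_dict(checkpoint['model_state_dict'])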
The helper classes and functions I have used are as follows:
class SummarizerNet(nn.Module):
    def __init__(self, model_name, args):
        super().__init__()
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.config = self.model.config
        self.encoder = self.model.get_encoder()
        self.decoder = self.model.get_decoder()
        self.lm_head = self.model.lm_head
        self.vocab_size = self.model.config.vocab_size

    def forward(self, input_ids, attention_mask, decoder_input_ids,
                decoder_attention_mask=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        decoder_outputs = self.decoder(input_ids=decoder_input_ids,
                                       attention_mask=decoder_attention_mask,
                                       encoder_hidden_states=encoder_outputs.last_hidden_state,
                                       encoder_attention_mask=attention_mask)
        # project to vocabulary logits, as PegasusForConditionalGeneration does
        return self.lm_head(decoder_outputs.last_hidden_state) + self.model.final_logits_bias
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
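For illustration, with Pegasus-style special tokens (pad_token_id=0, decoder_start_token_id=0) and hypothetical token ids, the shift behaves like this:

labels = torch.tensor([[4774, 18, 573, 1, -100, -100]])  # hypothetical ids; -100 marks ignored padding
shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
# -> tensor([[   0, 4774,   18,  573,    1,    0]])
#    position 0 is decoder_start_token_id; the -100s become pad_token_id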
def prepare_data(model_name,
                 train_texts, train_labels,
                 val_texts=None, val_labels=None,
                 test_texts=None, test_labels=None):
    """
    Prepare input data for model fine-tuning.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    prepare_val = val_texts is not None and val_labels is not None
    prepare_test = test_texts is not None and test_labels is not None

    train_dataset = tokenize_data(train_texts, train_labels, tokenizer)
    val_dataset = tokenize_data(val_texts, val_labels, tokenizer) if prepare_val else None
    test_dataset = tokenize_data(test_texts, test_labels, tokenizer) if prepare_test else None

    return train_dataset, val_dataset, test_dataset, tokenizer

def tokenize_data(texts, labels, tokenizer):
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
    decodings = tokenizer(labels, truncation=True, padding='max_length', max_length=256)
    dataset_tokenized = PegasusDatasetDA(encodings, decodings, tokenizer)
    return dataset_tokenized
The issue I am facing is that my model does not generate summaries as coherent as those from the un-split model (i.e., passing the same inputs through PegasusForConditionalGeneration directly). Where am I messing up?
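For reference, the comparison I am making between the two paths looks like this (a minimal sketch: the checkpoint name and token ids are placeholders, and in my actual runs both paths start from the same weights):

import torch
from transformers import PegasusForConditionalGeneration

full = PegasusForConditionalGeneration.from_pretrained('google/pegasus-xsum').eval()  # placeholder checkpoint

input_ids = torch.tensor([[4774, 18, 573, 1]])      # hypothetical source token ids
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.tensor([[0, 4774, 18]])   # starts with decoder_start_token_id (0)

with torch.no_grad():
    # path 1: the full model
    full_logits = full(input_ids=input_ids, attention_mask=attention_mask,
                       decoder_input_ids=decoder_input_ids).logits
    # path 2: split encoder/decoder, mirroring the training loop above
    enc = full.get_encoder()(input_ids=input_ids,
                             attention_mask=attention_mask).last_hidden_state
    dec = full.get_decoder()(input_ids=decoder_input_ids,
                             encoder_hidden_states=enc,
                             encoder_attention_mask=attention_mask).last_hidden_state
    split_logits = full.lm_head(dec) + full.final_logits_bias

print(torch.allclose(full_logits, split_logits, atol=1e-5))  # expect True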