@valhalla Here is the full code snippet:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
import sys
import pandas as pd
#import numpy as np
ZERO = sys.float_info.min
ZERO_PT = torch.tensor(ZERO)
class GPT2FinetunedWithNgrams(GPT2LMHeadModel):
    def __init__(self, config, model_tokenizer=None):
        super().__init__(config)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='right')
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def eval_sentence(self, sent: str):
        # Remove punctuation, lower-case, split on spaces, and wrap the tokens in
        # "<s>" / "</s>" start and stop tokens. Returns a list of strings.
        vec = self.sentence_vec(sent)
        last_idx = min(self.max_ngram, len(vec))
        log_prob = 0
        for i in range(2, last_idx + 1):
            # conditional probability with Katz backoff
            #log_prob += np.log(max(ZERO, self.pkatz(vec[0:i])))
            log_prob += torch.log(max(ZERO_PT, self.pkatz(vec[0:i])))
        for i in range(1, len(vec) - last_idx + 1):
            j = i + last_idx
            #log_prob += np.log(max(ZERO, self.pkatz(vec[i:j])))
            log_prob += torch.log(max(ZERO_PT, self.pkatz(vec[i:j])))
        return log_prob, len(vec)

    def sentence_loss(self, sent: str):
        p, l = self.eval_sentence(sent)
        return -p
    def generate_text_while_finetuning(self,
                                       input_ids=None,
                                       past=None,
                                       attention_mask=None,
                                       token_type_ids=None,
                                       position_ids=None,
                                       head_mask=None,
                                       inputs_embeds=None,
                                       labels=None,
                                       use_cache=None,
                                       output_attentions=None,
                                       output_hidden_states=None):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits,) + transformer_outputs[1:]
        return outputs  # lm_logits, presents, (all hidden_states), (attentions)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
    ):
        # Sample continuations from the current model, decode them to text, and
        # score the decoded text with the n-gram model to get the training loss.
        max_length = input_ids.shape[1] + 50
        full_generated_gpt2_ids = self.generate(input_ids=input_ids,
                                                max_length=max_length,
                                                is_finetuning_current_model=True,
                                                attention_mask=attention_mask,
                                                pad_token_id=50256,
                                                do_sample=True,
                                                top_k=50,
                                                top_p=0.95)
        decoded_gen_samples = self.tokenizer.batch_decode(full_generated_gpt2_ids, skip_special_tokens=True)
        tmp_losses = [self.sentence_loss(decoded_sample) for decoded_sample in decoded_gen_samples]
        losses = torch.stack(tmp_losses)
        loss = losses.mean()
        loss.requires_grad = True
        return (loss,)
## The code below is the run script.
class MyDataset(Dataset):
    def __init__(self, csv_file: str):
        self.df = pd.read_csv(csv_file, encoding='ISO-8859-1')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The training text is in the second column of the csv.
        text = self.df.iloc[idx, 1]
        return text


def my_data_collator(dataset_samples_list):
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding_side='right')
    tokenizer.pad_token = tokenizer.eos_token
    encoded_results = tokenizer(dataset_samples_list, padding=True, truncation=True,
                                return_tensors='pt', return_attention_mask=True)
    # The keys here have to line up with the keyword arguments of the model's forward().
    batch = {}
    batch['input_ids'] = encoded_results['input_ids']
    batch['past'] = None
    batch['attention_mask'] = encoded_results['attention_mask']
    batch['position_ids'] = None
    batch['head_mask'] = None
    batch['inputs_embeds'] = None
    batch['labels'] = None
    batch['use_cache'] = True
    return batch
dataset_train = MyDataset('/path/to/train_dataset.csv')
training_args = TrainingArguments(
    output_dir='/path/to/out',
    do_train=True,
    per_device_train_batch_size=64,
    logging_dir='/path/to/dir',
    max_steps=300000,
)

model = GPT2FinetunedWithNgrams.from_pretrained('gpt2')

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=my_data_collator,
    train_dataset=dataset_train,
)

trainer.train()
trainer.save_model('/path/to/model_save_dir')
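(sentence_vec, max_ngram, and pkatz are part of my Katz-backoff n-gram scoring code and aren't shown above. Roughly, sentence_vec does something like this; just a sketch, not the exact implementation:)

import re

def sentence_vec(self, sent: str):
    # sketch only: strip punctuation, lower-case, split on whitespace,
    # then wrap the tokens in "<s>" / "</s>" start and stop symbols
    tokens = re.sub(r"[^\w\s]", "", sent).lower().split()
    return ['<s>'] + tokens + ['</s>']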