import random

import ipywidgets as widgets
from IPython.display import display
from transformers import (DataCollatorForLanguageModeling, TextDataset,
                          Trainer, TrainingArguments)


# `model` and `tokenizer` are assumed to be loaded earlier in the notebook.
def finetune():
    progress = widgets.FloatProgress(value=0.1, min=0.0, max=1.0, bar_style='info')
    display(progress)
    block_size = tokenizer.model_max_length
    train_dataset = TextDataset(tokenizer=tokenizer,
                                file_path='/content/Armenian Para.txt',
                                block_size=block_size)
    # mlm=False: causal language modeling (GPT-2 style), not masked LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    epochs = 2.5  # stop after ~2.5 passes; longer training overfits the small corpus
    training_args = TrainingArguments(
        output_dir='/content/ParaCorpus',
        overwrite_output_dir=True,
        do_train=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=1,
        prediction_loss_only=True,
        logging_steps=5,
        save_steps=0,                       # no intermediate checkpoints; we save once at the end
        seed=random.randint(0, 2**32 - 1),  # fresh seed each run, so repeated fine-tunes differ
    )
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    progress.value = 0.4
    # Map training progress (current epoch / total epochs) onto the
    # remaining 40-100% of the progress bar.
    p_start, p_end = 0.4, 1.0

    def progressify(f):
        def inner(*args, **kwargs):
            if trainer.epoch is not None:
                progress.value = p_start + trainer.epoch / epochs * (p_end - p_start)
            return f(*args, **kwargs)
        return inner
    try:
        # Monkey-patch the training step so the bar advances as training runs.
        trainer.training_step = progressify(trainer.training_step)
        trainer.train()
    except KeyboardInterrupt:
        print('Training interrupted; saving the model trained so far.')
    finally:
        trainer.save_model('/content/ParaCorpus')
        tokenizer.save_pretrained('/content/ParaCorpus')
def clean_prediction(text):
    # Remove the end-of-text token, including truncated copies of it:
    # generation may stop mid-token, leaving fragments like '<|endoftex'.
    token = '<|endoftext|>'
    while len(token) > 1:
        text = text.replace(token, '')
        token = token[:-1]
    text = text.strip()
    # Drop a dangling closing quote if the quotes in the text are unbalanced.
    if len(text) != 0:
        if text[-1] == '"' and text.count('"') % 2:
            text = text[:-1]
    return text.strip()
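
# Quick illustration on a hypothetical model output:
# clean_prediction('example output<|endoftex')  ->  'example output'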
finetune()
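
# A minimal sketch of sampling from the fine-tuned model and cleaning the
# output. The prompt and sampling parameters below are illustrative
# assumptions, not taken from the original pipeline; it presumes `model` is
# a GPT-2-style causal LM (as the '<|endoftext|>' token implies).
input_ids = tokenizer.encode('<|endoftext|>', return_tensors='pt').to(model.device)
output = model.generate(
    input_ids,
    max_length=64,       # illustrative cap on generated length
    do_sample=True,      # sample rather than decode greedily
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,
)
print(clean_prediction(tokenizer.decode(output[0])))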