I’m new to this. I’m trying to fine-tune a BERT masked language model (bert-base-uncased) on a target domain. Unfortunately, the results are not good.
Before fine-tuning, the pre-trained model fills the mask of a sentence with words in line with human expectations.
E.g. Wikipedia is a free online [MASK], created and edited by volunteers around the world.
The most probable predictions are encyclopedia (score: 0.650) and resource (score: 0.087).
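For context, this is roughly how I get the baseline predictions, using the same fill-mask pipeline I use later for the fine-tuned model (a minimal sketch):

from transformers import pipeline

# Baseline check against the stock pre-trained checkpoint
baseline = pipeline('fill-mask', model='bert-base-uncased')
mask = baseline.tokenizer.mask_token
print(baseline(f'Wikipedia is a free online {mask}, created and edited by volunteers around the world'))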
After fine-tuning, the predictions are completely wrong; often stopwords are predicted as the top results.
E.g. Wikipedia is a free online [MASK], created and edited by volunteers around the world.
The most probable predictions are the (score: 0.052) and be (score: 0.033).
I experimented with different numbers of epochs (from 1 to 10) and different dataset sizes (from a few MB to a few GB), but I got the same issue. What am I doing wrong? I’m using the following code; I hope you can help me.
from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM

# Load the tokenizer and configuration for bert-base-uncased
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
config = AutoConfig.from_pretrained('bert-base-uncased', output_hidden_states=True)
model = AutoModelForMaskedLM.from_config(config)  # BertForMaskedLM.from_pretrained(path)
from transformers import LineByLineTextDataset

# One training example per line of the corpus, truncated to block_size tokens
dataset = LineByLineTextDataset(tokenizer=tokenizer,
                                file_path="data/english/corpora.txt",
                                block_size=512)
from transformers import DataCollatorForLanguageModeling

# Randomly mask 15% of tokens for the MLM objective
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="output/models/english",
                                  overwrite_output_dir=True,
                                  num_train_epochs=5,
                                  per_device_train_batch_size=8,
                                  save_steps=22222222,  # very large, so no intermediate checkpoints are written
                                  save_total_limit=2)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator, train_dataset=dataset)
trainer.train()
trainer.save_model("output/models/english")
from transformers import pipeline
# Initialize MLM pipeline
mlm = pipeline('fill-mask', model="output/models/english", tokenizer="output/models/english")
# Get mask token
mask = mlm.tokenizer.mask_token
# Get result for particular masked phrase
phrase = f'Wikipedia is a free online {mask}, created and edited by volunteers around the world'
result = mlm(phrase)
# Print result
print(result)
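For reference, what gets printed is the standard fill-mask output, a list of dicts sorted by score. Roughly (other fields elided; scores are the ones from the run quoted above):

# [{'sequence': 'wikipedia is a free online the, created and edited by volunteers around the world', 'score': 0.052, ...},
#  {'sequence': 'wikipedia is a free online be, created and edited by volunteers around the world', 'score': 0.033, ...},
#  ...]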