Hi all,
I’ve spent a couple of days trying to get this to work. I’m trying to pretrain BERT from scratch with the standard masked language modeling (MLM) objective; I’m pretraining rather than fine-tuning because my input is not natural language per se.
Here is my code:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import BertProcessing
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Split
from tokenizers.normalizers import Strip
from tokenizers import Regex
# the pattern matches the pieces to keep (invert=True): the leading method name, the [ENTRY] ... chunk, each [CALL] ... chunk, and [EXIT]
exp = Regex(r"(^((\w)+(?=\s)))|((\[ENTRY\]\ (\w|\||\.)+)\s)|((\[CALL\]\ (\w|\||\.|\s)+)(?=\ \[))+|(\[EXIT\])")
pre_tokenizer = Split(pattern=exp, behavior="removed", invert=True)
#print(pre_tokenizer.pre_tokenize_str("performExpensiveLogSetup [ENTRY] void [CALL] java.io.PrintStream println java.lang.String void [CALL] java.lang.Math pow double|double double [CALL] java.lang.Math sqrt double double [CALL] java.io.PrintStream println java.lang.String void [EXIT]"))
trace_tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
trace_tokenizer.add_special_tokens(["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
trace_tokenizer.normalizer = Strip()
trace_tokenizer.pre_tokenizer = pre_tokenizer
# look the special-token ids up instead of hard-coding them; with the order above, [SEP] is 2 and [CLS] is 1
trace_tokenizer.post_processor = BertProcessing(
    sep=("[SEP]", trace_tokenizer.token_to_id("[SEP]")),
    cls=("[CLS]", trace_tokenizer.token_to_id("[CLS]")),
)
VOCAB_SIZE = 5000
trace_tokenizer.add_special_tokens(['[PAD]'])  # already registered above, so this is a no-op
trace_tokenizer.add_tokens([' '])  # keep a literal space token in the vocab
trainer = WordLevelTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
)
files = ["10k_smaller_dataset.txt"]
trace_tokenizer.train(files, trainer)
trace_tokenizer.save("data/trace.json")
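As a quick sanity check (not part of the pipeline; the trace below is a shortened version of the commented-out example above), reloading the saved tokenizer and encoding one trace should show the post-processor wrapping it in [CLS]/[SEP]:
check_tok = Tokenizer.from_file("data/trace.json")
enc = check_tok.encode("performExpensiveLogSetup [ENTRY] void [CALL] java.lang.Math sqrt double double [EXIT]")
print(enc.tokens)  # expected to start with [CLS] and end with [SEP]
print(enc.ids)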
from transformers import BertConfig, BertForMaskedLM
scale_factor = 0.25
config = BertConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=int(768 * scale_factor),  # 192, comfortably above the max_length=128 used below
    intermediate_size=int(2048 * scale_factor),       # 512
    hidden_size=int(512 * scale_factor),              # 128, divisible by num_attention_heads
    num_attention_heads=8,
    num_hidden_layers=6,
    type_vocab_size=5,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
from transformers import PreTrainedTokenizerFast
# register all the special tokens on the wrapper; return_special_tokens_mask / return_token_type_ids belong in the tokenizer call, not the constructor
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=trace_tokenizer, unk_token="[UNK]", cls_token="[CLS]",
                                         sep_token="[SEP]", pad_token="[PAD]", mask_token="[MASK]")
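To verify the wrapper actually picked up the pad and mask tokens (the ids are whatever the trained vocab assigned), checking the resolved ids should print integers rather than None:
print(fast_tokenizer.pad_token, fast_tokenizer.pad_token_id)    # should not be None
print(fast_tokenizer.mask_token, fast_tokenizer.mask_token_id)  # needed by the MLM data collator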
from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': '10k_smaller_dataset.txt', 'test': 'tiny_eval.txt', 'eval': 'tiny_eval.txt'})
small_train_dataset = dataset["train"]
small_eval_dataset = dataset["test"]
model = BertForMaskedLM(config)
model.tokenizer = fast_tokenizer
def preprocess_function(examples):
    return fast_tokenizer(examples["text"], max_length=128, truncation=True, padding=True)
encoded_dataset_train = small_train_dataset.map(preprocess_function, batched=True)
encoded_dataset_test = small_eval_dataset.map(preprocess_function, batched=True)
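For reference, inspecting one preprocessed example should show the tokenized columns alongside the original text column (the exact columns may differ slightly; this is just illustrative):
sample = encoded_dataset_train[0]
print(list(sample.keys()))       # e.g. ['text', 'input_ids', 'token_type_ids', 'attention_mask']
print(len(sample["input_ids"]))  # at most 128 because of truncation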
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metric(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=-1)  # MLM logits -> predicted token ids
    mask = eval_pred.label_ids != -100                       # only the masked positions carry real labels
    return metric.compute(predictions=predictions[mask], references=eval_pred.label_ids[mask])
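As a standalone check that the metric object itself works on flat arrays (toy values, unrelated to my data):
print(metric.compute(predictions=[1, 2, 3], references=[1, 2, 4]))  # expect {'accuracy': 0.666...}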
from transformers import TrainingArguments
from transformers import DataCollatorForWholeWordMask
data_collator = DataCollatorForWholeWordMask(tokenizer=fast_tokenizer, mlm=True, mlm_probability=0.15)
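To see what the collator actually feeds the model, building a tiny batch by hand (again just a local check) should show input_ids plus labels that are -100 everywhere except at the masked positions:
features = [{"input_ids": encoded_dataset_train[i]["input_ids"]} for i in range(2)]
batch = data_collator(features)
print(batch.keys())                     # expect at least 'input_ids' and 'labels'
print((batch["labels"] != -100).sum())  # number of masked positions in this tiny batch
(Since my vocab is WordLevel there are no '##' continuation pieces, so as far as I can tell the whole-word collator ends up masking single tokens anyway.)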
training_args = TrainingArguments(
    "test_trainer_bert_pre",
    num_train_epochs=1,
    # prediction_loss_only=True,
)
from transformers import Trainer
trainer = Trainer(
    model=model,
    tokenizer=fast_tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=encoded_dataset_train,
    eval_dataset=encoded_dataset_test,
    compute_metrics=compute_metric,
)
train_result = trainer.train(resume_from_checkpoint=True)
train_result
trainer.evaluate(encoded_dataset_test)
The problem is with the last line: the accuracy metric I define never shows up in the output of trainer.evaluate(). All I get is:
{'epoch': 1.0,
'eval_loss': 0.0025006113573908806,
'eval_runtime': 1.9859,
'eval_samples_per_second': 503.54,
'eval_steps_per_second': 62.942}
I’m sure there’s a super simple mistake I’m making that causes it to be ignored. Any ideas?
Thank you in advance.
Best,
Claudio