With very little code, I get errors from deep inside the HF framework that I don't understand, and I cannot find documentation or examples for using BertForPreTraining together with the Trainer. Maybe someone here sees what I am doing wrong. Many thanks in advance.
I want to pretrain BERT for MLM on the data in text_0.txt (toy data). Because of BERT's NSP objective, I have to use TextDatasetForNextSentencePrediction.
text_0.txt:
I am very happy.
Here is the second sentence.
A new document.
This is a test.
What is going on here?
minimal script:
from transformers import BertModel, BertConfig, BertTokenizer, BertForPreTraining
from transformers import DataCollatorForLanguageModeling
from transformers.data.datasets.language_modeling import TextDatasetForNextSentencePrediction
from transformers import Trainer, TrainingArguments
model = BertForPreTraining(BertConfig())
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True
)
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer,
    file_path="./training/data/test/text_0.txt",
    block_size=10,
    overwrite_cache=True,
    short_seq_probability=0.1,
    nsp_probability=0.5
)
# this shows that the dataset has been collected
for i in range(len(dataset)):
    print(dataset[i])
args = TrainingArguments(
    output_dir="./test_results",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=30,
    save_steps=100,
    save_total_limit=5,
)
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=dataset
)
trainer.train()
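To narrow the problem down, I also inspected a single example before training (a sanity check I added; the trace below makes the same point):

example = dataset[0]
print(type(example))  # a dict, judging from the KeyError below
if isinstance(example, dict):
    print(list(example.keys()))  # apparently "input_ids" is not among them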
But when I run this minimal example, it leads to the following stack trace:
You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it.
<transformers.trainer.Trainer object at 0x0000026569E3FAF0>
Epoch: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "test_data_collator2.py", line 47, in <module>
    trainer.train()
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\transformers\trainer.py", line 755, in train
    for step, inputs in enumerate(epoch_iterator):
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
    data = self._next_data()
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\utils\data\dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\transformers\data\data_collator.py", line 135, in __call__
    examples = [e["input_ids"] for e in examples]
  File "C:\Users\pharn\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\transformers\data\data_collator.py", line 135, in <listcomp>
    examples = [e["input_ids"] for e in examples]
KeyError: 'input_ids'
Epoch: 0%| | 0/1 [00:00<?, ?it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]
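The KeyError comes from DataCollatorForLanguageModeling.__call__ doing examples = [e["input_ids"] for e in examples], so the examples produced by TextDatasetForNextSentencePrediction apparently don't carry an "input_ids" key. My unverified guess is that the NSP dataset needs its matching collator; this is roughly what I would try next, assuming DataCollatorForNextSentencePrediction is exported at the top level like the other collators (the mlm flag is also my assumption):

# unverified guess: pair the NSP dataset with the NSP collator instead of
# DataCollatorForLanguageModeling; import path and mlm flag are assumptions
from transformers import DataCollatorForNextSentencePrediction

data_collator = DataCollatorForNextSentencePrediction(
    tokenizer=tokenizer,
    mlm=True,
)

Is that the intended pairing, or am I misusing BertForPreTraining with the Trainer in some other way?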