Couple of questions about Trainer

I’m trying to train a BertForMaskedLM model from scratch on a domain-specific vocabulary. I’ve trained a whole-word tokenizer that seems to work well, but I’m seeing some issues with my Trainer setup.

Here are my questions, followed by the code for my trainer:

  1. It doesn’t look like compute_metrics is getting called.
  2. The validation loss is always nan. Why is that?
  3. Shouldn’t I be able to see the metric (in this case, accuracy) at each step checkpoint?
  4. When calling trainer.predict(test_dataset), the EvalPrediction items are all null.
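
import evaluate
import numpy as np
from transformers import (BertConfig, BertForMaskedLM, DataCollatorForLanguageModeling,
                          EvalPrediction, Trainer, TrainingArguments)

metrics = evaluate.load('accuracy')

def compute_metrics(p: EvalPrediction):
    # With output_hidden_states=True the predictions arrive as a tuple; the logits come first.
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=-1)
    result = metrics.compute(predictions=preds, references=p.label_ids)
    return result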

batch_size = 32

config = BertConfig(vocab_size=len(tokenizer),
                    output_hidden_states=True)
bert_model = BertForMaskedLM(config)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="./bert-checkpoints",
    overwrite_output_dir=True,
    num_train_epochs=2,
    auto_find_batch_size=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_steps=0,
    save_total_limit=2,
    prediction_loss_only=True,
    evaluation_strategy='steps',
    logging_steps=1000,
    eval_steps=1000,
    do_train=True,
    do_eval=True
)

trainer = Trainer(
    model=bert_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)

trainer.train()
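
For reference, here is a minimal sketch of the accuracy I would expect compute_metrics to report, written with plain NumPy so it can be checked without a model. It assumes that DataCollatorForLanguageModeling sets the label to -100 at every position that was not masked, so those positions have to be filtered out before comparing. The helper name masked_accuracy and the toy arrays are hypothetical and not part of my training script. Separately, as far as I understand, prediction_loss_only=True tells the Trainer to return only the loss, which may be related to questions 1, 3 and 4; I have not confirmed that.

```python
import numpy as np

IGNORE_INDEX = -100  # assumed label value for positions that were not masked


def masked_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Accuracy over masked positions only (hypothetical helper)."""
    preds = np.argmax(logits, axis=-1)   # shape: (batch, seq_len)
    keep = labels != IGNORE_INDEX        # True where a token was masked
    if not keep.any():
        # No masked positions in this batch, so accuracy is undefined.
        return float("nan")
    return float((preds[keep] == labels[keep]).mean())


# Toy example: batch of 1, sequence of 4 tokens, vocabulary of 3.
logits = np.array([[[0.1, 0.8, 0.1],
                    [0.7, 0.2, 0.1],
                    [0.2, 0.2, 0.6],
                    [0.3, 0.4, 0.3]]])
labels = np.array([[-100, 0, 1, -100]])  # only positions 1 and 2 are scored

print(masked_accuracy(logits, labels))
```

Boolean indexing flattens the kept positions into 1-D arrays, which is also the shape the accuracy metric from evaluate would need if it were used instead of the manual mean.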