The TrainerState's log_history is always empty when using a custom callback

I’m trying to use the Trainer API with a custom MLflowCallback object to log my metrics and artifacts to the AWS S3 artifact storage I have. The custom callback looks like this:

import logging
import os

import mlflow
from transformers import TrainerCallback


class MLflowCallback(TrainerCallback):
    def __init__(
        self,
        logger: logging.Logger,
        mlflow_uri: str,
        experiment_name: str,
        log_file: str,
        log_artifact_path: str = "logs",
        model_artifact_path: str = "models",
    ):
        self.logger = logger

        self.mlflow_uri = mlflow_uri
        self.experiment_name = experiment_name

        self.log_file = log_file
        self.log_artifact_path = log_artifact_path
        self.model_artifact_path = model_artifact_path

        self.logger.info("Setting MLflow URI to %s", mlflow_uri)
        mlflow.set_tracking_uri(mlflow_uri)

        self.logger.info("Setting MLflow experiment to %s", experiment_name)
        mlflow.set_experiment(experiment_name)

    def on_train_begin(self, args, state, control, **kwargs):
        run_name = os.path.basename(args.run_name)
        self.logger.info("Starting MLflow run with name %s", run_name)

        if mlflow.active_run():
            mlflow.end_run()

        mlflow.start_run(run_name=run_name)
        mlflow.log_params(vars(args))

    def on_step_end(self, args, state, control, **kwargs):
        if state.log_history:
            mlflow.log_metrics(
                {"training_loss": state.log_history[-1]["loss"]},
                step=state.global_step,
            )
        else:
            self.logger.warning("Skipping logging training loss at epoch %.4f because log_history is empty.", state.epoch)

        mlflow.log_artifact(
            local_path=self.log_file,
            artifact_path=self.model_artifact_path,
        )

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        mlflow.log_metrics(metrics, step=state.global_step)

        if state.is_best:
            mlflow.log_artifacts(
                local_path=args.output_dir,
                artifact_path=self.model_artifact_path,
            )

    def on_train_end(self, args, state, control, **kwargs):
        mlflow.end_run()

In the on_step_end method I put in an if-else statement to catch empty log histories. However, the log_history in the TrainerState is always empty.

Is there something that I’m unintentionally overriding with my custom callback?