I’m trying to use the Trainer API with a custom MLflowCallback object to log my metrics and artifacts to the AWS S3 artifact storage I have. The custom callback looks like this:
import logging
import os
import mlflow
from transformers import TrainerCallback
class MLflowCallback(TrainerCallback):
def __init__(
self,
logger: logging.Logger,
mlflow_uri: str,
experiment_name: str,
log_file: str,
log_artifact_path: str = "logs",
model_artifact_path: str = "models",
):
self.logger = logger
self.mlflow_uri = mlflow_uri
self.experiment_name = experiment_name
self.log_file = log_file
self.log_artifact_path = log_artifact_path
self.model_artifact_path = model_artifact_path
self.logger.info("Setting MLflow URI to %s", mlflow_uri)
mlflow.set_tracking_uri(mlflow_uri)
self.logger.info("Setting MLflow experiment to %s", experiment_name)
mlflow.set_experiment(experiment_name)
def on_train_begin(self, args, state, control, **kwargs):
run_name = os.path.basename(args.run_name)
self.logger.info("Starting MLflow run with name %s", run_name)
if mlflow.active_run():
mlflow.end_run()
mlflow.start_run(run_name=run_name)
mlflow.log_params(vars(args))
def on_step_end(self, args, state, control, **kwargs):
if state.log_history:
mlflow.log_metrics(
{"training_loss": state.log_history[-1]["loss"]},
step=state.global_step,
)
else:
self.logger.warning("Skipping logging training loss at epoch %.4f because log_history is empty.", state.epoch)
mlflow.log_artifact(
local_path=self.log_file,
artifact_path=self.model_artifact_path,
)
def on_evaluate(self, args, state, control, metrics, **kwargs):
mlflow.log_metrics(metrics, step=state.global_step)
if state.is_best:
mlflow.log_artifacts(
local_path=args.output_dir,
artifact_path=self.model_artifact_path,
)
def on_train_end(self, args, state, control, **kwargs):
mlflow.end_run()
In the on_step_end
method I put in an if-else statement to catch empty log histories. However, the log_history
in the TrainerState is always empty.
Is there something that I’m unintentionally overriding with my custom callback?