Trainer: logging my custom metrics during the training step

Hey there.
I’m using the Hugging Face Trainer to fine-tune my model, and TensorBoard to display the metrics.
I find that the Trainer only logs the train loss that is returned in the model output. However, I wonder whether there is a way to log more information during the training step, such as my own loss, which is one component of the train loss.
I checked the Trainer code:

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

I think this code fixes which metrics get logged, so I may not have the chance to add my own metric during the training step.
I would like to know whether there is some way to log it.

Cheers & thanks for any tips!

I have the same problem. Have you solved it?

If you don’t use gradient accumulation, I usually just hack it by overriding Trainer.compute_loss and tucking in one line: self.log(compute_my_metric(output)).
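
A rough sketch of that hack, in case it helps. It is written as a mixin so it does not depend on a particular transformers version; LogExtraLossMixin, compute_my_metric and the "my_loss" output key are names I made up for illustration, and it assumes your model returns its extra loss in the output under that key.

```python
class LogExtraLossMixin:
    """Combine with the real Trainer, e.g.:

        from transformers import Trainer
        class MyTrainer(LogExtraLossMixin, Trainer):
            pass
    """

    def compute_my_metric(self, outputs):
        # Hypothetical helper: assumes the model output exposes a scalar under
        # the key "my_loss". float() also accepts a 0-dim torch tensor.
        return {"my_loss": float(outputs["my_loss"])}

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Always request the outputs so the extra field can be read.
        # **kwargs forwards arguments that newer Trainer versions may pass.
        loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )
        # compute_loss also runs during evaluation, so only log while training.
        if model.training:
            self.log(self.compute_my_metric(outputs))
        return (loss, outputs) if return_outputs else loss
```

Note that this logs on every call to compute_loss, regardless of logging_steps, and converting the value to a float forces a device sync each time, so it is only reasonable for quick experiments.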

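If you use gradient accumulation, compute_loss runs once per micro-batch rather than once per optimizer step, so that hack logs several times per step. One alternative is to trigger a CustomCallback, as described in the thread "Metrics for Training Set in Trainer" (post #7 by Kaveri). For example, you can do one forward pass over the entire train set in on_epoch_end or on_evaluate. It repeats work, and it is slow and coarse.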
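
Here is a minimal sketch of the callback idea, assuming you already pass a compute_metrics function to the Trainer so that evaluate() reports your custom metric. TrainSetMetricsCallback is my own name for it, and the import fallback is only there so the snippet can be read and imported without transformers installed.

```python
import copy

try:
    from transformers import TrainerCallback
except ImportError:  # fallback so the sketch imports without transformers
    TrainerCallback = object


class TrainSetMetricsCallback(TrainerCallback):
    """Runs an extra evaluation pass over the training set after each epoch."""

    def __init__(self, trainer):
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        # evaluate() fires its own callback events and may modify `control`,
        # so return a snapshot taken before the extra pass.
        snapshot = copy.deepcopy(control)
        # Metrics are logged with a "train_" prefix instead of "eval_".
        self._trainer.evaluate(
            eval_dataset=self._trainer.train_dataset,
            metric_key_prefix="train",
        )
        return snapshot
```

You would register it with trainer.add_callback(TrainSetMetricsCallback(trainer)) before calling trainer.train(). Keep in mind that evaluate() runs the model in eval mode, so the numbers will not match the loss seen during the actual training steps.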

And let me know if you figure out an easy way to log a custom loss!

https://github.com/zipzou/hf-multitask-trainer
You can try to use this to track different losses.