How can I stop the log history from being printed every logging_steps?

I’m writing a custom ProgressCallback that modifies the original transformers ProgressCallback implementation and adds some additional information/data to the tqdm progress bar. Here’s what I have so far; it works nicely and as intended.

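import math

from tqdm.auto import tqdm
from transformers import TrainerCallback
from transformers.trainer_utils import has_length

# format_number_suffix is a helper defined elsewhere (it is not part of transformers).


class ProgressCallback(TrainerCallback):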
    """A [`TrainerCallback`] that displays the progress of training or evaluation.

    Specifically, it shows:
    1. Time spent so far in training or evaluation.
    2. Estimated time remaining for training or evaluation.
    3. Iterations per second.
    4. Loss.
    5. Number of input tokens seen so far.
    """

    def __init__(self):
        self.training_bar = None
        self.prediction_bar = None
        self.current_step: int = 0
        self.loss: float = math.nan
        self.num_input_tokens_seen = format_number_suffix(0)

    def on_train_begin(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)

    def on_step_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.training_bar.update(state.global_step - self.current_step)
            self.current_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        if state.is_world_process_zero and has_length(eval_dataloader):
            if self.prediction_bar is None:
                self.prediction_bar = tqdm(
                    total=len(eval_dataloader),
                    leave=self.training_bar is None,
                    dynamic_ncols=True,
                )
            self.prediction_bar.update(1)

    def on_evaluate(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            if self.prediction_bar is not None:
                self.prediction_bar.close()
            self.prediction_bar = None

    def on_predict(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            if self.prediction_bar is not None:
                self.prediction_bar.close()
            self.prediction_bar = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero and self.training_bar is not None:
            # The last callback_handler.on_log() call in the training loop logs `train_loss` as opposed to `loss`.
            # From some digging through transformers code, the `train_loss` is the average training loss
            # during training.
            # See: https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/trainer.py#L2025-L2026
            latest = state.log_history[-1] if state.log_history else {}
            loss = latest.get("loss", latest.get("train_loss"))
            if loss is None:
                # e.g. evaluation logs, which carry `eval_loss` instead; keep the previous postfix.
                return
            self.loss = loss
            self.num_input_tokens_seen = format_number_suffix(state.num_input_tokens_seen)
            self.training_bar.set_postfix_str(
                f"loss: {self.loss:.4f}, tokens: {self.num_input_tokens_seen}",
            )

    def on_train_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.training_bar.close()
            self.training_bar = None
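The callback relies on format_number_suffix, which isn't defined in the snippet. The helper below is a hypothetical version, written only to match the suffixes that appear in the progress bar output (for example 16384 → 16.38K); the actual helper used here may differ.

```python
def format_number_suffix(n: float) -> str:
    """Hypothetical helper: abbreviate a count with a K/M/B/T suffix.

    Values below 1,000 are returned unchanged; larger values are divided by
    the matching power of 1,000 and shown with two decimal places.
    """
    for divisor, suffix in ((1e12, "T"), (1e9, "B"), (1e6, "M"), (1e3, "K")):
        if abs(n) >= divisor:
            return f"{n / divisor:.2f}{suffix}"
    return str(n)
```

Checking the largest divisor first means a value such as 2,500,000 is shown as 2.50M rather than 2500.00K.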

In my TrainingArguments, I explicitly set disable_tqdm=True so I can pass this as a custom callback in place of the original ProgressCallback. I also set logging_steps=1 so that I get metrics back from every step through the log_history attribute of the TrainerState object.

The problem is that the metrics dict also gets printed to stdout at every logging step, and I’m not sure where in the code that comes from. I don’t want that behavior, since I want to surface the relevant information directly in my tqdm progress bar through my callback. Looking at the transformers Trainer, I’ve narrowed down that the metrics get passed to on_log in the callback, and that this seems to happen from within this function at the end of each training step and then again at the end of training: transformers/src/transformers/trainer.py at v4.27.2 · huggingface/transformers · GitHub

When I set a breakpoint at the end of on_log in my callback, I can confirm that the logs dict has not been printed to stdout yet. So the printing happens somewhere between the end of my on_log and the start of the next training step, but I’m not sure whether I’m missing something obvious, since I’m still new to the transformers codebase.
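A likely source, based on reading trainer.py (treat this as an assumption to verify against the installed version): when disable_tqdm=True, the Trainer registers PrinterCallback instead of ProgressCallback, and PrinterCallback.on_log prints the logs dict. It is added after the user-supplied callbacks, so it would run after the custom on_log returns, which would match the breakpoint observation. The self-contained sketch below models that dispatch order; Callback, PrinterLikeCallback, PostfixCallback and CallbackHandler here are simplified hypothetical stand-ins, not the real transformers classes.

```python
from typing import Any, Dict, List, Optional


class Callback:
    """Hypothetical stand-in for TrainerCallback: hooks are no-ops by default."""

    def on_log(self, logs: Dict[str, Any]) -> None:
        pass


class PrinterLikeCallback(Callback):
    """Models what PrinterCallback appears to do: print the raw logs dict."""

    def on_log(self, logs: Dict[str, Any]) -> None:
        print(logs)


class PostfixCallback(Callback):
    """Models the custom ProgressCallback: keep the loss for a progress bar."""

    def __init__(self) -> None:
        self.postfix: Optional[str] = None

    def on_log(self, logs: Dict[str, Any]) -> None:
        loss = logs.get("loss", logs.get("train_loss"))
        if loss is not None:  # ignore e.g. eval-only logs
            self.postfix = f"loss: {loss:.4f}"


class CallbackHandler:
    """Calls each registered callback's on_log in registration order."""

    def __init__(self) -> None:
        self.callbacks: List[Callback] = []

    def add_callback(self, cb: Callback) -> None:
        self.callbacks.append(cb)

    def remove_callback(self, cb_class: type) -> None:
        # Drop every registered callback that is an instance of cb_class.
        self.callbacks = [cb for cb in self.callbacks if not isinstance(cb, cb_class)]

    def on_log(self, logs: Dict[str, Any]) -> None:
        for cb in self.callbacks:
            cb.on_log(logs)


if __name__ == "__main__":
    handler = CallbackHandler()
    progress = PostfixCallback()
    handler.add_callback(progress)
    # Registered last, so it runs after the custom callback's on_log returns.
    handler.add_callback(PrinterLikeCallback())
    handler.on_log({"loss": 10.3748, "epoch": 0.09})  # the dict is printed
    handler.remove_callback(PrinterLikeCallback)
    handler.on_log({"loss": 10.3741, "epoch": 0.17})  # nothing is printed
    print(progress.postfix)
```

If this model holds, the real-library equivalent would presumably be calling trainer.remove_callback(PrinterCallback) after constructing the Trainer, with PrinterCallback imported from transformers.trainer_callback.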

Here’s what I see in my output:

***** Running training *****
  Num examples = 183
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 16
  Total optimization steps = 33
  Number of trainable parameters = 256
  3%|██▍                                                                               | 1/33 [00:01<00:34,  1.07s/it, loss: 10.3748, tokens: 16.38K]{'loss': 10.3748, 'learning_rate': 0.00019393939393939395, 'epoch': 0.09, 'num_input_tokens_seen': 16384}
  6%|████▉                                                                             | 2/33 [00:01<00:22,  1.39it/s, loss: 10.3741, tokens: 32.77K]{'loss': 10.3741, 'learning_rate': 0.0001878787878787879, 'epoch': 0.17, 'num_input_tokens_seen': 32768}
  9%|███████▍                                                                          | 3/33 [00:02<00:18,  1.66it/s, loss: 10.3737, tokens: 49.15K]{'loss': 10.3737, 'learning_rate': 0.00018181818181818183, 'epoch': 0.26, 'num_input_tokens_seen': 49152}
 12%|█████████▉                                                                        | 4/33 [00:02<00:15,  1.83it/s, loss: 10.3748, tokens: 65.54K]{'loss': 10.3748, 'learning_rate': 0.00017575757575757578, 'epoch': 0.35, 'num_input_tokens_seen': 65536}
 15%|████████████▍                                                                     | 5/33 [00:02<00:14,  1.93it/s, loss: 10.3729, tokens: 81.92K]{'loss': 10.3729, 'learning_rate': 0.00016969696969696972, 'epoch': 0.44, 'num_input_tokens_seen': 81920}

Here’s what I want to see instead, but I can’t figure out where the log_history/logs get printed in the training loop:

***** Running training *****
  Num examples = 183
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 16
  Total optimization steps = 33
  Number of trainable parameters = 256
 15%|████████████▍                                                                     | 5/33 [00:02<00:14,  1.93it/s, loss: 10.3729, tokens: 81.92K]

Any help would be greatly appreciated!