Logging text from model outputs to TensorBoard


I would like to log text generated during training with the Trainer class to TensorBoard. I’ve been looking at the TensorBoardCallback class, but it doesn’t seem to give easy access to the model outputs. I came up with a solution, but it feels quite hacky. First, I override compute_loss so that it stores the logits on the trainer state:

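from transformers import Seq2SeqTrainer

class CustomTrainer(Seq2SeqTrainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer releases
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        # Stash the logits on the trainer state so a callback can read them in on_log;
        # detach so the autograd graph is not kept alive until the next log
        self.state.logits = outputs["logits"].detach()
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss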
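Stripped of the Hugging Face classes, this hand-off is a stash-and-consume pattern on a shared object. The sketch below uses hypothetical ToyTrainer and ToyCallback classes (not part of transformers) and a SimpleNamespace standing in for TrainerState. It makes one consequence visible: if several forward passes run between two on_log calls (for example with logging_steps > 1 or gradient accumulation), only the last batch's logits survive to be logged.

```python
from types import SimpleNamespace


class ToyTrainer:
    """Hypothetical stand-in for a Trainer: compute_loss stashes outputs on shared state."""

    def __init__(self, state):
        self.state = state

    def compute_loss(self, batch):
        logits = [[float(x)] for x in batch]  # pretend forward pass
        self.state.logits = logits  # overwrite: only the latest batch is kept
        return sum(batch) / len(batch)


class ToyCallback:
    """Hypothetical stand-in for the callback: reads the stash, then clears it."""

    def __init__(self):
        self.logged = []

    def on_log(self, state):
        logits = getattr(state, "logits", None)
        if logits is None:  # nothing new since the last log
            return
        self.logged.append(logits)
        state.logits = None  # consume so the same batch is not logged twice


state = SimpleNamespace()
trainer, callback = ToyTrainer(state), ToyCallback()
trainer.compute_loss([1, 2])
trainer.compute_loss([3, 4])  # replaces the first batch's logits
callback.on_log(state)
```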

I then read those logits in an overridden on_log:

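import random
import torch
from transformers import TrainerControl, TrainerState, TrainingArguments
from transformers.integrations import TensorBoardCallback, rewrite_logs
from transformers.utils import logging

logger = logging.get_logger(__name__)

class CustomCallback(TensorBoardCallback):
    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            # Logits stashed by CustomTrainer.compute_loss (absent e.g. for eval-only logs)
            logits = getattr(state, "logits", None)
            if logits is not None:
                preds = torch.argmax(logits, dim=-1)
                idx = random.randint(0, logits.shape[0] - 1)
                # Older releases pass "tokenizer"; newer ones pass "processing_class"
                tokenizer = kwargs.get("tokenizer") or kwargs.get("processing_class")
                # Decode only the sampled row instead of the whole batch
                pred_text = tokenizer.decode(preds[idx], skip_special_tokens=True)
                self.tb_writer.add_text("preds", pred_text, global_step=state.global_step)
                state.logits = None  # consume so the same batch is not logged twice
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()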
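The argmax-and-decode step in on_log can be checked without a model. This torch-free sketch uses hypothetical helpers (greedy_ids, decode, sample_prediction) and a toy vocabulary standing in for tokenizer.decode(..., skip_special_tokens=True). Keep in mind that the argmax of teacher-forced logits gives each position's prediction given the ground-truth prefix, so it can read differently from text produced autoregressively by model.generate.

```python
import random
from typing import Dict, List, Optional, Sequence


def greedy_ids(logits: Sequence[Sequence[Sequence[float]]]) -> List[List[int]]:
    """Per-position argmax over the vocabulary axis: (batch, seq, vocab) -> (batch, seq)."""
    return [[max(range(len(pos)), key=pos.__getitem__) for pos in seq] for seq in logits]


def decode(ids: Sequence[int], vocab: Dict[int, str], special_ids: Sequence[int] = ()) -> str:
    """Hypothetical toy stand-in for tokenizer.decode(..., skip_special_tokens=True)."""
    return " ".join(vocab[i] for i in ids if i not in special_ids)


def sample_prediction(logits, vocab, special_ids=(), idx: Optional[int] = None, rng=random) -> str:
    """Greedy-decode every row, then return the text of one (random or given) row."""
    ids = greedy_ids(logits)
    if idx is None:
        idx = rng.randrange(len(ids))  # equivalent to random.randint(0, len(ids) - 1)
    return decode(ids[idx], vocab, special_ids)


logits = [[[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]],
          [[0.0, 0.0, 1.0], [0.2, 0.7, 0.1]]]
vocab = {0: "<pad>", 1: "hello", 2: "world"}
text = sample_prediction(logits, vocab, special_ids=(0,), rng=random.Random(0))
```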

Is there another way to retrieve the model outputs from within on_log?