Hello,
I would like to log text generated during training with the Trainer
class to TensorBoard. I'm looking into the TensorBoardCallback
class, but it seems like I can't easily access the model outputs from it. I came up with a solution, but it seems quite hacky:
class CustomTrainer(Seq2SeqTrainer):
    """Trainer that stashes the latest batch logits on ``self.state`` so a
    callback (e.g. a TensorBoard callback's ``on_log``) can decode and log
    sample predictions as text.
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run the forward pass, record the logits, and return the loss.

        Args:
            model: the model invoked as ``model(**inputs)``.
            inputs: the batch dict forwarded verbatim to the model.
            return_outputs: when True, return ``(loss, outputs)`` instead of
                just the loss (matches the base Trainer contract).

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs``.
        """
        outputs = model(**inputs)
        # Detach before stashing: keeping the raw logits would hold the
        # entire autograd graph (and its activations) alive until on_log
        # fires, leaking GPU memory across steps.
        self.state.logits = outputs["logits"].detach()
        # outputs may be a dict-like ModelOutput or a plain tuple.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss
which I then read in my override of on_log:
class CustomCallback(TensorBoardCallback):
    """TensorBoard callback that, alongside the usual scalar logs, decodes a
    randomly chosen prediction from the logits stashed by the trainer (see
    ``CustomTrainer.compute_loss``) and logs it as text.
    """

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs,
    ):
        # Only the main process writes to TensorBoard.
        if not state.is_world_process_zero:
            return
        # The stash may be absent (e.g. on_log fires before any training
        # step has run compute_loss) — guard instead of crashing.
        logits = getattr(state, "logits", None)
        pred_text = None
        if logits is not None:
            preds = torch.argmax(logits, dim=-1)
            # Sample one row and decode only it, rather than decoding the
            # whole batch and discarding all but one entry.
            idx = random.randint(0, logits.shape[0] - 1)
            # NOTE(review): assumes the Trainer passes the tokenizer via
            # kwargs — confirm for the transformers version in use.
            pred_text = kwargs["tokenizer"].decode(
                preds[idx], skip_special_tokens=True
            )
            # Drop the stash so stale logits never get logged twice.
            del state.logits
        if self.tb_writer is None:
            self._init_summary_writer(args)
        if self.tb_writer is not None:
            if pred_text is not None:
                self.tb_writer.add_text(
                    "preds", pred_text, global_step=state.global_step
                )
            # logs defaults to None; rewrite_logs expects a mapping.
            logs = rewrite_logs(logs or {})
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    print(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()
Is there another way to retrieve outputs from my model within on_log?