I know this is an old issue, but I came across this while trying to determine the best way to track metrics besides the loss during training. I thought I’d post what I came up with in case it helps someone else.
To reiterate the context: like @Bumblebert, I'm interested in computing additional metrics on the outputs the model already produces during training, rather than doing a separate evaluation pass over the entire training set (e.g., with `self.evaluate(self.train_dataset)`). My use case is a multiple-choice model, and I'd like to see how its accuracy changes while it trains.
The `Trainer` class suggests that you "Subclass and override for custom behavior," and that advice has served me well a couple of times now. To compute custom metrics, I found the place where the model outputs are easily accessible, `compute_loss()`, and added some code there. My comments below are prefixed with `MAX:`.
```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        MAX: Subclassed to compute training accuracy.

        How the loss is computed by Trainer. By default, all models return the loss in
        the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)

        # MAX: Start of new stuff.
        # (With label smoothing enabled, "labels" was popped above, so this is skipped.)
        if "labels" in inputs:
            preds = outputs.logits.detach()
            acc = (
                (preds.argmax(axis=1) == inputs["labels"])
                .float()
                .mean()
                .item()
            )
            # The metric key is arbitrary; it is the name that shows up in the logs.
            self.log({"train/acc": acc})
        # MAX: End of new stuff.

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
```
Then, I instantiate a `CustomTrainer` instead of a `Trainer` and run as normal.
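For a multiple-choice model the logits have shape `(batch_size, num_choices)`, so `argmax(axis=1)` picks the predicted choice for each example, and the accuracy is the fraction of predictions that match the labels. Here is a minimal, dependency-free sketch of that same computation, using plain Python lists in place of tensors; `batch_accuracy` is a hypothetical helper name, not part of `transformers`:

```python
from typing import Sequence


def batch_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring choice equals the label.

    Hypothetical helper mirroring (preds.argmax(axis=1) == labels).float().mean()
    on a (batch_size, num_choices) logits tensor.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    if not labels:
        raise ValueError("empty batch")
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest logit is the predicted choice (the first one on ties).
        pred = max(range(len(row)), key=row.__getitem__)
        correct += int(pred == label)
    return correct / len(labels)


# Two examples with three choices each: the first is predicted correctly, the second is not.
print(batch_accuracy([[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]], [1, 2]))  # 0.5
```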
(Note that the above code isn't battle-tested, and I've only tried it on a single GPU, so take it as a starting point and with a grain of salt.)
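Logging the accuracy of every batch gives a noisy curve, since each point comes from a single batch. One untested option (my own assumption, not something `Trainer` provides) is to accumulate correct/total counts across steps and log the mean only every few steps. A framework-agnostic sketch; `RunningAccuracy` and its methods are hypothetical names:

```python
class RunningAccuracy:
    """Hypothetical helper: accumulates correct/total counts between logging steps."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, num_correct: int, batch_size: int) -> None:
        # Add one batch's counts (e.g., from every call to compute_loss).
        self.correct += num_correct
        self.total += batch_size

    def compute_and_reset(self) -> float:
        # Accuracy over all examples seen since the last reset, then start over.
        if self.total == 0:
            raise ValueError("no batches recorded since last reset")
        acc = self.correct / self.total
        self.correct = 0
        self.total = 0
        return acc


meter = RunningAccuracy()
meter.update(3, 4)  # batch 1: 3 of 4 correct
meter.update(1, 4)  # batch 2: 1 of 4 correct
print(meter.compute_and_reset())  # 0.5
```

Weighting by example count (rather than averaging per-batch accuracies) keeps a smaller final batch from skewing the result. Inside the subclass, one could call `update` on every step and only call `self.log` when `self.state.global_step % self.args.logging_steps == 0`; with gradient accumulation or multiple GPUs, each process would hold only its own counts.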
I started using the wandb integration, which receives the values passed to the `self.log()` call we added and automatically plots them:
(my runs were called