Unpacking transformers' Trainer.evaluate() to see every example's output and loss

I’m sure this is an FAQ, but I’ve not found the answer anywhere:

I have a transformer trained with the BertForMultipleChoice model, and
trainer.evaluate() produces only summary metrics. I want to get
per-example results over the eval_dataset.

I’ve tried something like this, invoking the model’s tokenizer on the sample text:

for inp in eval_dataset:
    ex_id = inp['example_id']
    lbl = inp['label']
    txt = inp['citing_prompt']
    toks = tokenizer.tokenize(txt)  # returns a list of token strings
    outs = model(toks)              # fails: the model expects tensors, not strings
    print(ex_id, lbl, outs)

but this fails because I'm passing a list of token strings rather than the tensor inputs the model expects:

outs = model(toks)
       ^^^^^^^^^^^
File "…/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
File "…/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
File "…/transformers/models/bert/modeling_bert.py", line 1662, in forward
    num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

What additional work (beyond tokenizing) do I need to do to produce the encoded input the model expects? I am also interested in capturing each example's encoding (e.g., to look at similarities among examples).
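
For concreteness, something like the following appears to be what BertForMultipleChoice expects: tensors of shape (batch_size, num_choices, seq_len), one encoded (prompt, choice) pair per candidate. This is a minimal sketch with made-up prompt/choice strings, assuming tokenizer is the checkpoint's matching tokenizer:

import torch

prompt = "Example citing prompt ..."   # hypothetical text
choices = ["choice A", "choice B"]     # hypothetical candidates

# encode one (prompt, choice) pair per candidate, padded to equal length
enc = tokenizer([prompt] * len(choices), choices,
                return_tensors="pt", padding=True)
# add the batch dimension: (1, num_choices, seq_len)
inputs = {k: v.unsqueeze(0) for k, v in enc.items()}

model.eval()
with torch.no_grad():
    outs = model(**inputs)
print(outs.logits)  # shape (1, num_choices)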

Or, is there a way (or place) to tap into the Trainer's evaluation loop to get the per-example stats I need?
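
One possible route: trainer.predict() returns per-example predictions and labels, so per-example losses can be recomputed outside the Trainer. A rough sketch, assuming eval_dataset still carries its 'example_id' and 'label' columns:

import torch
import torch.nn.functional as F

pred_out = trainer.predict(eval_dataset)        # PredictionOutput
logits = torch.as_tensor(pred_out.predictions)  # (n_examples, num_choices)
labels = torch.as_tensor(pred_out.label_ids)    # (n_examples,)

# reduction="none" keeps one loss value per example
losses = F.cross_entropy(logits, labels, reduction="none")
for ex, lbl, loss in zip(eval_dataset, labels, losses):
    print(ex['example_id'], lbl.item(), loss.item())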

thanks for any hints!

The place within trainer.evaluate() that comes closest to the nexus I was seeking seems to be transformers.trainer.Trainer.compute_loss() (around line 3059).
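
A hedged sketch of tapping that nexus by subclassing Trainer (the compute_loss signature below matches the version I was reading; newer releases add extra arguments, so check your installed source):

from transformers import Trainer

class PerExampleTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # delegate to the stock implementation, but always ask for outputs
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
        # stash per-batch logits for later per-example analysis
        if not hasattr(self, "captured_logits"):
            self.captured_logits = []
        self.captured_logits.append(outputs.logits.detach().cpu())
        return (loss, outputs) if return_outputs else loss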

But for now, my self-answer is to use the captum.ai utilities.

I'm pretty new to NLP, so please forgive me if I'm wrong, but I think the problem is that you're not really giving the model what it wants. My impression is that the Trainer class expects data in the form of the output of tokenizer(text_data), not data in the form of tokenizer.tokenize(text_data).

The printed results for tokenizer(text) aren't actually the direct outputs; I converted the output with the dict function before printing.

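Something like this (the token ids shown assume a bert-base-uncased tokenizer):

text = "Hello world"
print(tokenizer.tokenize(text))
# ['hello', 'world']   (plain token strings, not model input)
print(dict(tokenizer(text)))
# {'input_ids': [101, 7592, 2088, 102],
#  'token_type_ids': [0, 0, 0, 0],
#  'attention_mask': [1, 1, 1, 1]}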
