I am trying to train a very simple seq2seq model. The training data has the usual `input_ids`, `labels` and `attention_mask` columns, plus an additional column (`example_id`) that lets me map each batch item back to the original dataset.
This `example_id` lets me associate the evaluation metrics with the original raw dataset and study the behaviour of the model on specific classes of input data points. I keep this column in the batch using this custom collator:
```python
from transformers import DataCollatorForSeq2Seq


class MyCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        # Pull the metadata out so the standard collator never sees it.
        example_ids = [feature["example_id"] for feature in features]
        features = [
            {k: v for k, v in feature.items() if k != "example_id"}
            for feature in features
        ]
        # Standard collation (padding, tensorization).
        batch = super().__call__(features, return_tensors="pt")
        assert isinstance(example_ids[0], int), "problem with example ids in batching"
        # Reattach the metadata to the collated batch.
        batch["example_id"] = example_ids
        return batch
```
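The pattern the collator follows — strip the metadata, delegate to the base collator, reattach — can be sketched without any transformers dependency. The names below (`collate_with_metadata`, `toy_collate`) are illustrative stand-ins, not part of any library API:

```python
# Framework-free sketch of the strip-and-reattach pattern: pull the
# metadata out, collate the rest, put the metadata back on the batch.

def collate_with_metadata(features, base_collate, meta_key="example_id"):
    """Collate `features` with `base_collate`, preserving `meta_key`."""
    example_ids = [f[meta_key] for f in features]
    stripped = [{k: v for k, v in f.items() if k != meta_key} for f in features]
    batch = base_collate(stripped)
    batch[meta_key] = example_ids
    return batch

# Toy base collator that just groups values per key (a stand-in for
# DataCollatorForSeq2Seq, which would additionally pad and tensorize).
def toy_collate(features):
    return {k: [f[k] for f in features] for k in features[0]}

features = [
    {"input_ids": [1, 2], "labels": [3], "example_id": 0},
    {"input_ids": [4, 5], "labels": [6], "example_id": 1},
]
batch = collate_with_metadata(features, toy_collate)
# batch["example_id"] == [0, 1]; the base collator never saw the metadata.
```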
When I pass a batch built as described above to a standard `Seq2SeqTrainer`, I get an error about an unexpected key in the batch:
```
T5ForConditionalGeneration.forward() got an unexpected keyword argument 'example_id'
```
Is there a simple way to make the `Trainer` accept metadata fields in the batch, or is subclassing the trainer the only option?
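For context, a framework-free sketch of why the error occurs and what the subclassing route would amount to: the model's `forward` has a fixed signature, so any extra batch key raises a `TypeError`, and a subclassed loss step would pop the metadata before the model call. `toy_forward` and `compute_loss_with_metadata` below are hypothetical stand-ins, not transformers code:

```python
# A forward with a fixed signature, standing in for
# T5ForConditionalGeneration.forward: any unknown keyword argument
# (e.g. "example_id") raises TypeError, as in the question.
def toy_forward(input_ids=None, labels=None, attention_mask=None):
    return {"loss": 0.0}

def compute_loss_with_metadata(batch):
    # Pop the metadata before the model call; keep it for metric bookkeeping.
    example_ids = batch.pop("example_id", None)
    outputs = toy_forward(**batch)
    return outputs["loss"], example_ids

batch = {"input_ids": [[1, 2]], "labels": [[3]], "example_id": [0]}
loss, ids = compute_loss_with_metadata(batch)
# Without the pop, toy_forward(**batch) would raise TypeError.
```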