Metadata in batches

I am trying to train a very simple seq2seq model. The training data consists of the usual input_ids, labels and attention_mask columns, plus an additional column (example_id) that lets me map each batch entry back to the original dataset.

In order to associate the evaluation metrics with the original raw dataset, I have added an “example_id” column that allows me to study the behaviour of the model on specific classes of input data points. I keep this column in the batch using this custom collator:

class MyCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        # Pull the metadata out before the standard collator sees it,
        # since it would otherwise try to pad/tensorise it.
        example_ids = [feature["example_id"] for feature in features]
        features = [
            {k: v for k, v in feature.items() if k != "example_id"}
            for feature in features
        ]

        # Run the standard seq2seq collation on the remaining fields.
        batch = super().__call__(features, return_tensors="pt")

        # Reattach the metadata to the finished batch.
        assert isinstance(example_ids[0], int), "problem with example ids in batching"
        batch["example_id"] = example_ids
        return batch
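The pop-and-reattach pattern itself works as expected outside the Trainer. Here is a minimal stand-alone sanity check (base_collate is a hypothetical stand-in for DataCollatorForSeq2Seq, which needs a real tokenizer; the field names match my setup):

```python
def base_collate(features):
    # Stand-in for DataCollatorForSeq2Seq.__call__: just gathers
    # each field into a list (no padding or tensor conversion).
    return {k: [f[k] for f in features] for k in features[0]}

def collate_with_metadata(features):
    # Pull the metadata out before the base collator sees it...
    example_ids = [f["example_id"] for f in features]
    features = [{k: v for k, v in f.items() if k != "example_id"} for f in features]
    batch = base_collate(features)
    # ...then reattach it to the finished batch.
    batch["example_id"] = example_ids
    return batch

features = [
    {"input_ids": [1, 2], "labels": [3], "example_id": 0},
    {"input_ids": [4], "labels": [5], "example_id": 1},
]
batch = collate_with_metadata(features)
assert batch["example_id"] == [0, 1]
```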

When I pass a batch built as described above to a standard Seq2SeqTrainer, I get an error about an unexpected key in the batch:

T5ForConditionalGeneration.forward() got an unexpected keyword argument 'example_id'

Is there a simple way to make the Trainer accept metadata fields in the batch, or is subclassing the Trainer the only way?
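The workaround I am considering, in case it helps frame the question: subclass Seq2SeqTrainer and pop "example_id" from the inputs in compute_loss (and prediction_step for evaluation) before they reach model(**inputs). The core of that, filtering the batch down to what the model's forward signature accepts, can be sketched without transformers at all; strip_unexpected_kwargs and fake_forward below are illustrative names, not library API:

```python
import inspect

def strip_unexpected_kwargs(fn, batch):
    # Keep only the batch keys that fn's signature actually accepts;
    # this is what a compute_loss override would do before fn(**batch).
    params = inspect.signature(fn).parameters
    return {k: v for k, v in batch.items() if k in params}

# Stand-in for T5ForConditionalGeneration.forward
def fake_forward(input_ids, attention_mask, labels):
    return {"loss": 0.0}

batch = {
    "input_ids": [[1, 2]],
    "attention_mask": [[1, 1]],
    "labels": [[3]],
    "example_id": [42],  # metadata that forward() would reject
}

model_inputs = strip_unexpected_kwargs(fake_forward, batch)
assert "example_id" not in model_inputs
fake_forward(**model_inputs)  # no TypeError now
```

But if the Trainer already supports this without a subclass, I would prefer that.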
