In @sgugger's example notebook Token Classification, DataCollatorForTokenClassification is used:
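```python
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)
trainer = Trainer(
    # model, args, datasets, etc. as in the notebook (elided here)
    data_collator=data_collator,
)
```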
I am trying to figure out the real purpose of this. It appears that DataCollatorForTokenClassification is for padding, truncation, etc. But you can also do this in the tokenizer, so why do we need this extra thing? Is it because the DataCollator does it per batch, on the fly, and is therefore more efficient?
hey @hamel, welcome to the forum!
you’re spot on about using data collators to do padding on-the-fly. to understand why this helps, consider the following scenarios:
1. use the tokenizer to pad each example in the dataset to the length of the longest example in the dataset
2. use the tokenizer and DataCollatorWithPadding (docs) to pad each example in a batch to the length of the longest example in the batch
clearly, scenario 2 is more efficient, especially when a few examples happen to be much longer than the median length: scenario 1 would then pad every example in the dataset up to those outliers, introducing a lot of unnecessary padding.
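to make the efficiency argument concrete, here's a minimal pure-Python sketch (no transformers dependency) that counts how many pad tokens each scenario produces. the helper names and the toy lengths are hypothetical, chosen only to include one long outlier:

```python
def pad_tokens_static(lengths):
    """Scenario 1: pad every example to the longest example in the whole dataset."""
    max_len = max(lengths)
    return sum(max_len - n for n in lengths)


def pad_tokens_dynamic(lengths, batch_size):
    """Scenario 2: pad each batch only to the longest example in that batch."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        batch_max = max(batch)
        total += sum(batch_max - n for n in batch)
    return total


# Hypothetical token counts: mostly short examples plus one long outlier.
lengths = [12, 15, 10, 14, 11, 13, 9, 200]

print("static padding tokens:", pad_tokens_static(lengths))  # 1316
print("dynamic padding tokens:", pad_tokens_dynamic(lengths, batch_size=4))  # 576
```

the outlier still inflates the batch it lands in, but with per-batch padding it no longer forces every other batch up to 200 tokens.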
btw under the hood, the data collators are fed to the collate_fn argument of pytorch’s DataLoader, see e.g. here: transformers/trainer.py at commit 4e7bf94e7280d2b725ac4644dbe9808560afa5d8 in huggingface/transformers on GitHub
the pytorch docs are not super informative on collate_fn itself, but you can find various discussions in their forums (e.g. here)
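as for why there's a token-classification-specific collator at all: in token classification the labels are per-token lists, so they need padding to the same length as input_ids too. DataCollatorForTokenClassification does this, padding labels with label_pad_token_id (-100 by default) so those positions are ignored by the loss. here's a simplified, hypothetical stand-in for that kind of collate function, written in plain Python and returning lists rather than tensors. the pad id of 0 and the example token ids are assumptions for illustration:

```python
from typing import Dict, List

PAD_TOKEN_ID = 0     # assumption: the tokenizer's pad id (0 for many BERT-style vocabularies)
LABEL_PAD_ID = -100  # matches the default ignore_index of PyTorch's CrossEntropyLoss


def token_classification_collate(features: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Pad inputs, attention masks and labels to the longest example in this batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [PAD_TOKEN_ID] * n_pad)
        # real tokens get mask 1, padding gets mask 0
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # assumes labels are already aligned one-to-one with input_ids
        batch["labels"].append(f["labels"] + [LABEL_PAD_ID] * n_pad)
    return batch


features = [
    {"input_ids": [101, 7592, 102], "labels": [-100, 3, -100]},
    {"input_ids": [101, 7592, 2088, 999, 102], "labels": [-100, 3, 5, 0, -100]},
]
print(token_classification_collate(features))
```

with pytorch installed, a function like this (after converting the lists to tensors, e.g. with torch.tensor) could be passed as DataLoader(dataset, batch_size=..., collate_fn=token_classification_collate), which is the same hook the Trainer uses for its data_collator.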