I’ll give you the full picture.
The workflow:

You create an instance of `GlueDataset(data_args, tokenizer)` and pass it to the `Trainer(...)` class. You also pass `default_data_collator` to the `Trainer`. The reason is that `GlueDataset` returns `InputFeatures` objects, which are HF-specific and can't be consumed by the PyTorch dataloader directly. So `default_data_collator` takes in a `List[InputFeatures]` and returns a dict, and that dict is the batch the dataloader yields.
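To make that concrete, here is a minimal sketch of `default_data_collator` in action, with made-up toy features shaped like what `GlueDataset` yields:

```python
from transformers import default_data_collator

# Toy features (hypothetical values) shaped like GlueDataset output
features = [
    {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "label": 0},
    {"input_ids": [101, 2025, 102], "attention_mask": [1, 1, 1], "label": 1},
]

batch = default_data_collator(features)
# batch is a dict of batched tensors; note "label" is renamed to "labels":
# {"input_ids": tensor of shape (2, 3),
#  "attention_mask": tensor of shape (2, 3),
#  "labels": tensor([0, 1])}
```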
So if you pass `default_data_collator` to the `Trainer` together with a `TensorDataset`, it won't work directly (that's why you're getting the error): a `TensorDataset` yields plain tuples of tensors, not `InputFeatures`, and the error is raised when the dataloader passes such a batch to `default_data_collator`. I'd suggest using the default PyTorch `collate_fn` with your `TensorDataset`; it would work just fine.
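For example, something like this should work (a sketch with dummy tensors; no `collate_fn` is passed, so PyTorch's default collate stacks the tuple elements):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data standing in for your encoded inputs
input_ids = torch.randint(0, 30522, (8, 16))
attention_mask = torch.ones(8, 16, dtype=torch.long)
labels = torch.randint(0, 2, (8,))

dataset = TensorDataset(input_ids, attention_mask, labels)
loader = DataLoader(dataset, batch_size=4)  # default collate_fn

batch = next(iter(loader))
# batch is a list [input_ids, attention_mask, labels], each with batch size 4
```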
One more thing: make sure the dataloader returns a dict with the same keys that the model's `forward` method expects, as in the sketch below.
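A sketch of such a dataset (the class name and fields are just examples): PyTorch's default collate already knows how to batch dicts of tensors, so the dataloader will yield a dict you can feed straight into `model(**batch)`:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Yields dicts whose keys match the model's forward() arguments."""
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return {
            "input_ids": self.input_ids[i],
            "attention_mask": self.attention_mask[i],
            "labels": self.labels[i],
        }

dataset = DictDataset(
    torch.randint(0, 30522, (8, 16)),
    torch.ones(8, 16, dtype=torch.long),
    torch.randint(0, 2, (8,)),
)
loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))  # dict of tensors, ready for model(**batch)
```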
Inside `_training_step`, the `Trainer` passes `inputs` to the method, and after the inputs are moved to the GPU, it does:
`output = model(**inputs)`
In this case, the keyword arguments have to match the parameters of your model's `forward`. In case they don't, you can inherit from `Trainer` and redefine the method yourself, for example:
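Here is a sketch of that. The method name and signature match older transformers versions that define `_training_step`; in newer versions you would override `training_step` or `compute_loss` instead, and the `"features"` key renaming is a purely hypothetical example:

```python
from transformers import Trainer

class MyTrainer(Trainer):
    def _training_step(self, model, inputs, optimizer):
        model.train()
        # Hypothetical remap: rename a batch key to match forward()
        if "features" in inputs:
            inputs["input_ids"] = inputs.pop("features")
        # Move everything to the right device, as the base class does
        inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
        outputs = model(**inputs)
        loss = outputs[0]  # models return the loss first when labels are given
        loss.backward()
        return loss.item()
```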
I hope this answers your question.