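In transformers 4.20.1:
from transformers import Trainer, TrainingArguments
# per_device_eval_batch_size sets the batch size used by predict()
args = TrainingArguments(output_dir='tmp_trainer', per_device_eval_batch_size=16)
trainer = Trainer(model=model, args=args)
predictions = trainer.predict(pred_dataset)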
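
If you then want class labels rather than raw scores, one possible follow-up is sketched below. It assumes the model is a classifier, so that `predictions.predictions` is an array of logits with one row per example. The helper name `logits_to_labels` and the sample array are made up for illustration; with a real trainer you would pass `predictions.predictions` instead.

```python
import numpy as np


def logits_to_labels(logits):
    """Return the index of the highest-scoring class for each row of logits."""
    return np.argmax(np.asarray(logits), axis=-1)


# Hypothetical stand-in for `predictions.predictions` (3 examples, 2 classes).
# Assumption: the model is a sequence classifier that outputs one logit per class.
example_logits = np.array([[0.1, 2.3], [1.7, -0.4], [0.0, 0.5]])

predicted_labels = logits_to_labels(example_logits)
print(predicted_labels)
```

This only covers the single-label classification case; other heads (regression, token classification) would need different post-processing.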