Batch size for trainer.predict()

In transformers 4.20.1, I call trainer.predict() like this:

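from transformers import Trainer, TrainingArguments

# model and pred_dataset are defined earlier in my script
args = TrainingArguments(output_dir='tmp_trainer', per_device_eval_batch_size=16)

trainer = Trainer(model=model, args=args)

predictions = trainer.predict(pred_dataset)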
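As far as I understand it, predict() builds its DataLoader from args.eval_batch_size, and that property is per_device_eval_batch_size multiplied by max(1, n_gpu). So with the settings above, each prediction batch should hold 16 examples on a single GPU or CPU, and 32 with two GPUs under DataParallel. The sketch below reproduces only that arithmetic in plain Python so it can be checked without loading a model. The helper names effective_eval_batch_size and num_prediction_steps are hypothetical and not part of the transformers API.

```python
import math


def effective_eval_batch_size(per_device_eval_batch_size: int, n_gpu: int) -> int:
    # Hypothetical helper: mirrors the assumed rule for
    # TrainingArguments.eval_batch_size (per-device size times max(1, n_gpu)).
    return per_device_eval_batch_size * max(1, n_gpu)


def num_prediction_steps(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    # Hypothetical helper: number of batches a DataLoader yields for num_examples.
    # With drop_last=True the final partial batch is discarded.
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


if __name__ == "__main__":
    bs = effective_eval_batch_size(per_device_eval_batch_size=16, n_gpu=1)
    print(bs)                              # 16
    print(num_prediction_steps(1000, bs))  # 63
```

If memory is the concern, lowering per_device_eval_batch_size should shrink the batches predict() uses in the same way it does for evaluate().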