Batch size for trainer.predict()


I pass a test dataset to trainer.predict, but it has many samples, so I get an out-of-memory error. Does the library support batched prediction in trainer.predict, or do I have to implement it myself?


You can pass eval_accumulation_steps=xxx to move the predictions to the CPU every xxx steps; this should help.
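To make the idea concrete, here is a minimal pure-Python sketch of the accumulate-then-flush pattern that eval_accumulation_steps stands for: keep recent batch predictions in a device-side buffer and move them to host memory every N steps, so GPU memory never holds all predictions at once. All names below are illustrative stand-ins, not the Trainer's internal API.

```python
def predict_with_accumulation(model, batches, eval_accumulation_steps):
    """Run `model` over `batches`, flushing results every N steps.

    `host_preds` stands in for predictions already moved to the CPU;
    `device_buffer` stands in for predictions still sitting on the GPU.
    """
    host_preds = []
    device_buffer = []
    for step, batch in enumerate(batches, start=1):
        device_buffer.append(model(batch))
        if step % eval_accumulation_steps == 0:
            host_preds.extend(device_buffer)  # "move to CPU"
            device_buffer.clear()
    host_preds.extend(device_buffer)          # flush any remainder
    return host_preds


# Toy usage: a "model" that doubles every value in a batch.
model = lambda batch: [2 * x for x in batch]
batches = [[1, 2], [3, 4], [5, 6]]
preds = predict_with_accumulation(model, batches, eval_accumulation_steps=2)
# → [[2, 4], [6, 8], [10, 12]]
```

The real Trainer does the same thing with tensors and `.cpu()` calls, trading a little host-transfer overhead for bounded GPU memory.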


You can set the batch size manually using trainer.prediction_loop()

Instead of using trainer.predict(test_dataset), you can use a torch DataLoader with trainer.prediction_loop(). Thus, you might change


raw_pred, _, _ = trainer.predict(test_dataset)


to


from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
raw_pred, _, _ = trainer.prediction_loop(test_loader, description="prediction")

In transformers 4.20.1, you can set the evaluation batch size through TrainingArguments:

args = TrainingArguments(output_dir='tmp_trainer', per_device_eval_batch_size=16)
trainer = Trainer(model=model, args=args)
predictions = trainer.predict(pred_dataset)

Hi, I tried this method, but the prediction process is killed at 99% without generating the predictions. There are no memory issues. It looks like when I use the trainer.prediction_loop() method, I cannot set the argument predict_with_generate=True; I think this might be causing the problem, but I am not sure. I am very new to working with pre-trained models. Could you please share your insights on what the possible reason for this issue could be?