Batch size for trainer.predict()


I pass a test dataset to trainer.predict, but it has many samples, so I get a memory error. Does the library support batch-based prediction with trainer.predict, or do I have to implement it myself?


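You can pass eval_accumulation_steps=xxx to TrainingArguments so that the accumulated predictions are moved to the CPU every xxx steps. If it is left unset, all predictions are accumulated on the GPU before being moved to the CPU, which is faster but needs more memory. This should help.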
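The idea behind eval_accumulation_steps can be illustrated with a framework-free sketch. Each batch's outputs go into a small "device" buffer, and every accumulation_steps batches that buffer is emptied into a "host" list standing in for CPU memory, so only a few batches of outputs are ever held on the device at once. The function predict_with_offload, the toy double model and the list-based "device"/"host" stores are hypothetical illustrations, not the Trainer's actual code. The batch size itself is a separate setting: trainer.predict takes it from per_device_eval_batch_size in TrainingArguments.

```python
from typing import Callable, List, Sequence, Tuple


def predict_with_offload(
    model_fn: Callable[[Sequence[int]], List[int]],
    samples: Sequence[int],
    batch_size: int,
    accumulation_steps: int,
) -> Tuple[List[int], int]:
    """Hypothetical sketch of offloading predictions every N steps.

    Returns the concatenated predictions and the largest number of batch
    outputs that were ever held in the "device" buffer at once.
    """
    device_buffer: List[List[int]] = []  # stands in for tensors kept on the GPU
    host_preds: List[int] = []           # stands in for CPU memory
    peak_buffered = 0

    for step, start in enumerate(range(0, len(samples), batch_size), start=1):
        batch = samples[start:start + batch_size]
        device_buffer.append(model_fn(batch))
        peak_buffered = max(peak_buffered, len(device_buffer))
        # Every `accumulation_steps` batches, move the buffered outputs to the host.
        if step % accumulation_steps == 0:
            for out in device_buffer:
                host_preds.extend(out)
            device_buffer.clear()

    # Flush any outputs left over from the last, incomplete window.
    for out in device_buffer:
        host_preds.extend(out)
    return host_preds, peak_buffered


def double(batch: Sequence[int]) -> List[int]:
    """Toy stand-in for a model's forward pass."""
    return [2 * x for x in batch]


preds, peak = predict_with_offload(double, list(range(10)), batch_size=3, accumulation_steps=2)
print(preds)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
print(peak)   # 2
```

Setting accumulation_steps larger than the number of batches makes peak equal the total number of batches, i.e. every output stays "on device" until the end.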


You can set the batch size yourself by building the DataLoader and calling trainer.prediction_loop() directly.

Instead of calling trainer.predict(test_dataset), wrap the dataset in a torch DataLoader with the batch size you want and pass it to trainer.prediction_loop(). In other words, change

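
raw_pred, _, _ = trainer.predict(test_dataset)

to

from torch.utils.data import DataLoader

# shuffle=False keeps the predictions in the same order as test_dataset
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# .predictions works whether the loop returns a 3-field PredictionOutput (older
# versions) or a 4-field EvalLoopOutput (newer ones), so no tuple unpacking is needed
raw_pred = trainer.prediction_loop(test_loader, description="prediction").predictions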
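The "implement it myself" fallback from the question boils down to the same thing the DataLoader approach does: split the test set into fixed-size, non-shuffled chunks, run the model on one chunk at a time, and concatenate the per-chunk outputs in order. Below is a minimal, framework-free sketch; iter_batches, predict_in_batches and the toy square model are hypothetical names. In real PyTorch code the per-batch call would typically sit inside torch.no_grad(), with the outputs moved to the CPU via .cpu() before being kept.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def iter_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive, non-shuffled batches; the last one may be smaller."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(items)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def predict_in_batches(
    model_fn: Callable[[List[T]], List[R]],
    dataset: Sequence[T],
    batch_size: int = 64,
) -> List[R]:
    """Run model_fn one batch at a time and concatenate outputs in order."""
    predictions: List[R] = []
    for batch in iter_batches(dataset, batch_size):
        # Only one batch of inputs is processed at a time; the outputs are
        # appended in dataset order, mirroring DataLoader(shuffle=False).
        predictions.extend(model_fn(batch))
    return predictions


def square(batch: List[int]) -> List[int]:
    """Toy stand-in for a model's forward pass."""
    return [x * x for x in batch]


data = list(range(130))
print([len(b) for b in iter_batches(data, 64)])  # [64, 64, 2]
print(predict_in_batches(square, data) == square(data))  # True
```

Because the batches are never shuffled, the concatenated result matches what a single call on the whole dataset would return, just without holding everything in memory at once.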