Hi,
I pass a test dataset to trainer.predict but I have many samples. Therefore, I get a memory error. Does the library support a way of batch based trainer.predict? or do I have to implement it myself?
You can pass eval_accumulation_steps=xxx in your TrainingArguments to move the accumulated predictions to the CPU every xxx steps; this should help.
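To make the mechanism concrete, here is a toy sketch (not the library's code) of what accumulating for a fixed number of steps and then offloading looks like; the predict_in_chunks helper and the logits_stream input are hypothetical names for illustration only:

```python
import numpy as np

def predict_in_chunks(logits_stream, accumulation_steps):
    # Toy version of the eval_accumulation_steps idea: hold a few
    # batches of predictions in a buffer, then flush the accumulated
    # block to the host every `accumulation_steps` steps, so the
    # full prediction array is never kept on the device at once.
    host_chunks, device_buffer = [], []
    for step, logits in enumerate(logits_stream, start=1):
        device_buffer.append(logits)
        if step % accumulation_steps == 0:
            host_chunks.append(np.concatenate(device_buffer))
            device_buffer = []
    if device_buffer:  # flush any leftover partial block
        host_chunks.append(np.concatenate(device_buffer))
    return np.concatenate(host_chunks)
```

In the real Trainer the flush moves tensors from GPU to CPU memory; the trade-off is a little extra transfer overhead in exchange for a much smaller peak memory footprint.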
You can set the batch size manually using trainer.prediction_loop().
Instead of calling trainer.predict(test_dataset), wrap the dataset in a torch DataLoader and pass it to trainer.prediction_loop(). That is, change
raw_pred, _, _ = trainer.predict(test_dataset)
into:
from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
raw_pred, _, _ = trainer.prediction_loop(test_loader, description="prediction")
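If you do end up implementing it yourself, the pattern is the same regardless of framework: run the model on fixed-size slices and concatenate the per-batch outputs. A minimal sketch, where batched_predict and predict_fn are hypothetical names standing in for your own prediction call:

```python
import numpy as np

def batched_predict(predict_fn, examples, batch_size=64):
    # Run predict_fn on one fixed-size slice at a time and
    # concatenate the per-batch outputs, so only a single batch
    # of inputs is in flight at any moment.
    outputs = []
    for start in range(0, len(examples), batch_size):
        outputs.append(predict_fn(examples[start:start + batch_size]))
    return np.concatenate(outputs)
```

The last slice may be shorter than batch_size; slicing handles that automatically, so no padding is needed.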