Hi,
I have been trying to run inference with a model I've fine-tuned on a large dataset.
I first did it the way shown in the Summary of the tasks docs, iterating over all the questions and contexts one by one, but that's too slow.
The approach from the course seems pretty good, but I run into memory issues, I assume because the whole dataset gets loaded into a single dict and passed to the model as one giant batch?
import torch
from transformers import AutoModelForQuestionAnswering

# builds one batch containing the entire eval set at once
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}

trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(device)

with torch.no_grad():
    outputs = trained_model(**batch)
Is there some way I can pass the dataset directly, like I would in Lightning, and iterate over the batches dynamically?
I.e., instead of building the batch manually as above, do something like:

for batch in iter(dataset):
    pred = model(**batch)

?
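To make the question concrete, here is roughly what I'm picturing, just a rough sketch using a plain PyTorch DataLoader with transformers' default_data_collator; the batch size and variable names are placeholders and I don't know if this is the recommended way:

# rough sketch (untested): stream the eval set in small batches instead of one giant dict
from torch.utils.data import DataLoader
from transformers import default_data_collator

eval_set_for_model.set_format("torch")  # have the dataset return PyTorch tensors
loader = DataLoader(eval_set_for_model, batch_size=32, collate_fn=default_data_collator)

all_start_logits, all_end_logits = [], []
with torch.no_grad():
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = trained_model(**batch)
        # move logits back to CPU so GPU memory stays bounded
        all_start_logits.append(outputs.start_logits.cpu())
        all_end_logits.append(outputs.end_logits.cpu())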
Thanks a lot in advance