Background
I am working on a research problem where the model has to make predictions on a dataset it itself generates during previous inference steps.
Problem description
Given that the previous inference steps use `Trainer`'s distributed inference capability, I was wondering how to use all workers to run inference on the data generated on the fly.
My proposed solutions
Inefficient approach
At the end of the predict method, we call `torch.distributed.all_gather` so all previous task outputs are available to all workers. Therefore, every process has access to the raw data and can construct a `Dataset`, process it, and run inference on it. Because the dataset gets wrapped in a `DataLoader` inside `Trainer.predict`, it will be indexed with different indices based on the local rank and world size. This is all well and good, but the issue is that the data processing will happen in every process.
Better approach but unsure if correct
We could ensure that we do the processing on the head node (global rank 0), save the dataset, and then load it in the other processes. Assuming `self` is an instance of `Trainer`, I would write the following code:
```python
# block replicas and execute processing in the main process, global rank 0
with self.args.main_process_first(
    local=False, desc="Constructing and processing dataset on the fly"
):
    # process_index ranges over [0, world_size)
    if self.args.process_index == 0:
        dataset = datasets.Dataset.from_dict(dataset_dict)
        # process dataset (map returns a new dataset, so reassign)
        dataset = dataset.map(...)
        # save to disk so the replicas can load it
        dataset.save_to_disk("path/to/dataset/directory")
    else:
        dataset = datasets.load_from_disk("path/to/dataset/directory")

# call .predict here...
```
The above code would be called in a custom trainer method that handles inference for multiple tasks and delegates calls to `Trainer.predict` once the tasks are ready to run.
The logic of my approach is to use `save_to_disk` and `load_from_disk` to emulate what happens in a lot of the example training scripts (e.g. `run_summarization.py`), where the initial processing happens on the head node while the replicas are blocked, and the replicas then load the dataset from a cache file as opposed to processing the data multiple times.
Feedback on my understanding & proposal is appreciated! @lhoestq I know you are a datasets expert, so let me know if my approach seems correct!