Distributed inference for datasets created on the fly


I am working on a research problem where the model has to make predictions on a dataset that it generated itself during previous inference steps.

Problem description

Given that the previous inference steps use the Trainer's distributed inference capability, I was wondering how I can use all workers to run inference on the data generated on the fly.

My proposed solutions

Inefficient approach

At the end of the predict method, we call torch.distributed.all_gather so that all previous task outputs are available to all workers. Therefore, all the processes have access to the raw data, and each of them can construct a Dataset, process it and run inference on it. Because the dataset gets wrapped by a DataLoader inside Trainer.predict, each process is fed different indices based on its rank and the world size. This works, but the issue is that the data processing is duplicated in every process.
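A minimal sketch of the two mechanics in this approach, assuming torch is installed. For arbitrary picklable Python objects (such as lists of decoded strings) the collective is torch.distributed.all_gather_object; the helper below falls back to the local outputs when no process group is initialised, so it also runs in a single process. The sharding helper uses torch's DistributedSampler with shuffle=False purely to illustrate rank-based indexing; the exact sampler Trainer uses internally depends on the transformers/Accelerate version, so treat this as an analogy rather than Trainer's code path. The function names gather_outputs and shard_for_rank are hypothetical.

```python
import torch.distributed as dist
from torch.utils.data import DistributedSampler


def gather_outputs(local_outputs):
    """Concatenate the (picklable) outputs of every process."""
    if dist.is_available() and dist.is_initialized():
        gathered = [None] * dist.get_world_size()
        # every rank contributes its list and receives everyone else's
        dist.all_gather_object(gathered, local_outputs)
    else:
        # single-process fallback so the sketch runs without a process group
        gathered = [local_outputs]
    return [item for part in gathered for item in part]


def shard_for_rank(dataset, rank, world_size):
    """Indices one rank would read under DistributedSampler(shuffle=False)."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    return list(sampler)
```

With 10 examples and 4 ranks, rank 0 reads indices 0, 4 and 8, while rank 3 reads 3, 7 and 1: DistributedSampler pads the index list by repeating its first entries so every rank gets the same number of samples, which means gathered results can contain duplicates that need to be dropped.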

Better approach but unsure if correct

We could do the processing only on the head node (rank 0), save the dataset, and then load it in the other processes. Assuming self is an instance of Trainer, I would write the following code:

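    from datasets import Dataset, load_from_disk

    # with local=False this path must be on a filesystem shared by all nodes
    dataset_path = "path/to/dataset/directory"

    # block replicas and execute processing in the main process (process_index 0)
    with self.args.main_process_first(
        local=False, desc="Constructing and processing dataset on the fly"
    ):
        # process_index ranges over [0, world_size - 1]
        if self.args.process_index == 0:
            dataset = Dataset.from_dict(dataset_dict)
            # process dataset here, e.g. dataset = dataset.map(...)
            dataset.save_to_disk(dataset_path)
        # replicas only reach this line after rank 0 has finished saving
        dataset = load_from_disk(dataset_path)
    # call self.predict(dataset) here, outside the block, so that every rank
    # enters predict's collective operations together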

The above code would be called in a custom trainer method that handles inference for multiple tasks and delegates calls to Trainer.predict once the tasks are ready to run.

The logic of my approach is to use save_to_disk and load_from_disk to emulate what happens in many of the example training scripts (e.g. run_summarization.py), where the initial processing happens on the head node while the replicas are blocked and then load the dataset from a cache file instead of processing the data multiple times.
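To make the ordering concrete, here is a stdlib-only simulation that uses threads as stand-in ranks and a JSON file in place of save_to_disk/load_from_disk. The main_process_first context manager below is a simplified re-implementation written for illustration, not the transformers source: non-main ranks wait at a barrier before entering the block, and the main rank waits at the same barrier only after leaving it. The names expensive_processing and run_simulation are hypothetical.

```python
import json
import os
import tempfile
import threading
from contextlib import contextmanager


@contextmanager
def main_process_first(rank, barrier):
    """Simplified 'main process first' ordering (illustrative, not transformers)."""
    is_main = rank == 0
    try:
        if not is_main:
            barrier.wait()  # replicas block until the main rank leaves the block
        yield
    finally:
        if is_main:
            barrier.wait()  # main rank releases the waiting replicas


def expensive_processing(raw):
    # hypothetical stand-in for tokenisation / feature construction
    return [text.upper() for text in raw]


def run_simulation(raw, world_size=4):
    barrier = threading.Barrier(world_size)
    calls = []  # records which ranks ran the processing
    results = [None] * world_size
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "processed.json")

        def worker(rank):
            with main_process_first(rank, barrier):
                if rank == 0:
                    calls.append(rank)
                    with open(path, "w") as f:
                        json.dump(expensive_processing(raw), f)
                # rank 0 reads its own file; replicas read only after the barrier
                with open(path) as f:
                    results[rank] = json.load(f)

        threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return results, calls
```

Calling run_simulation(["a", "b"]) runs the processing once (calls == [0]) and every rank ends up with ["A", "B"]. The same ordering explains why predict should be called after the block: if rank 0 started a collective operation inside it, the replicas would still be waiting at the barrier and the job would hang.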

Feedback on my understanding and proposal is appreciated! @lhoestq I know you are a datasets expert, so let me know if my approach seems correct!

It sounds good to me! Though I haven't heard of this kind of setup before.

Where does the dataset_dict come from? Is it gathered from all the nodes?

In that case, maybe each node could do the processing on its side and call save_to_disk, and at evaluation time the nodes can reload all the partial datasets, merge them, and start the distributed evaluation.


Hi @lhoestq, thanks so much for engaging. I agree this is unusual; it's new research with LLMs that I am carrying out for my PhD. What is happening here is that I get the LLM to do error correction on some output it generated before, and a deterministic program checks whether the output should be corrected or not. This program works on strings, so I need to detokenize the outputs and use some outside information to decide whether to query the model again for a given example. Since compute_metrics runs at the end of the generation step, converts everything to strings, and I can pass it the metadata it needs to correct the errors, I thought I could also use it to build dataset_dict. I wasn't sure how to avoid having every worker preprocess the resulting dataset, so thanks so much for your input.
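A minimal sketch of how dataset_dict could be assembled inside a compute_metrics-style step, assuming the predictions have already been detokenised (for example with the tokenizer's batch_decode(..., skip_special_tokens=True)) and that the metadata list is aligned with the predictions in the same order. The checker needs_correction is a hypothetical stand-in for the deterministic program: here it flags any output that is not a JSON object containing the keys listed in its metadata. build_dataset_dict and the column names are likewise hypothetical.

```python
import json
from typing import Callable, Dict, List, Sequence


def needs_correction(output: str, meta: dict) -> bool:
    """Hypothetical deterministic checker: output must be a JSON object with required keys."""
    try:
        parsed = json.loads(output)
    except json.JSONDecodeError:
        return True
    if not isinstance(parsed, dict):
        return True
    return any(key not in parsed for key in meta["required_keys"])


def build_dataset_dict(
    decoded_preds: Sequence[str],
    metadata: Sequence[dict],
    checker: Callable[[str, dict], bool] = needs_correction,
) -> Dict[str, List]:
    """Keep only the examples the checker flags for another round of querying."""
    dataset_dict: Dict[str, List] = {"example_id": [], "previous_output": []}
    for example_id, (pred, meta) in enumerate(zip(decoded_preds, metadata)):
        if checker(pred, meta):
            dataset_dict["example_id"].append(example_id)
            dataset_dict["previous_output"].append(pred)
    return dataset_dict
```

The returned dict has equal-length columns, so it can be passed directly to Dataset.from_dict on rank 0 in the main_process_first approach above; storing example_id makes it possible to map the corrected outputs back onto the original examples.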