Distributed inference for datasets created on the fly

Background

I am working on a research problem where the model has to make predictions on a dataset it itself generates during previous inference steps.

Problem description

Given that the processing for previous inference steps uses the Trainer's distributed inference capability, I was wondering how it is possible to use all workers to do inference on the data generated on the fly.

My proposed solutions

Inefficient approach

At the end of the predict method, we call torch.distributed.all_gather (or all_gather_object for arbitrary Python objects) so all previous task outputs are available to all workers. Every process then has access to the raw data and can construct a Dataset, process it, and run inference on it. Because the dataset gets wrapped by a DataLoader inside Trainer.predict, it will be indexed with different indices based on the local rank and world size. This is all well and good, but the issue is that the data processing will happen in every process.
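
A minimal sketch of what I mean, assuming torch.distributed is initialised; local_task_outputs, build_examples and preprocess_fn are hypothetical stand-ins for this worker's raw outputs, a flattening helper and the processing transform:

    import torch.distributed as dist
    from datasets import Dataset

    # gather the raw outputs of the previous task from every worker
    world_size = dist.get_world_size()
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_task_outputs)

    # every process now holds the full raw data and builds the same dataset...
    dataset = Dataset.from_dict(build_examples(gathered))
    # ...and, wastefully, every process also repeats the processing
    dataset = dataset.map(preprocess_fn, batched=True)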

Better approach but unsure if correct

We could ensure that we do the processing on the head node (rank 0), save the dataset, and then load it in the other processes. Assuming self is an instance of Trainer, I would write the following code:

    # block replicas and execute processing in the main process of node 0 (global rank 0)
    with self.args.main_process_first(
        local=False, desc="Constructing and processing dataset on the fly"
    ):
        # process_index ranges over [0, world_size)
        if self.args.process_index == 0:
            dataset = datasets.Dataset.from_dict(dataset_dict)
            # process the dataset (map returns a new dataset, it does not modify in place)
            dataset = dataset.map(...)
            # save to disk so the replicas can load the processed result
            dataset.save_to_disk("path/to/dataset/directory")
        else:
            dataset = datasets.load_from_disk("path/to/dataset/directory")
    # call .predict here...

The above code would be called in a custom trainer method that handles inference for multiple tasks and delegates calls to Trainer.predict once the tasks are ready to run.

The logic of my approach is to use save_to_disk and load_from_disk to emulate what happens in many of the example training scripts (e.g. run_summarization.py), where the initial processing happens on the head node while the replicas are blocked, and the replicas then load the dataset from a cache file instead of processing the data multiple times.
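
For comparison, the pattern in those scripts looks roughly like this (paraphrased from run_summarization.py; preprocess_function and column_names are whatever the script defines):

    # rank 0 runs the map first and writes the cache; when the replicas are
    # unblocked they issue the same map call, hit the cache and simply load it
    with training_args.main_process_first(desc="train dataset map pre-processing"):
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on train dataset",
        )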

Feedback on my understanding & proposal is appreciated! @lhoestq I know you are a dataset expert, so let me know if my approach seems correct!

It sounds good to me! Though I’ve not heard about this kind of setup before.

Where does the dataset_dict come from? Is it gathered from all the nodes?

In that case maybe each node could do the processing on its side and save_to_disk - and at evaluation time the nodes can reload all the partial datasets, merge them and start the distributed evaluation.


Hi @lhoestq, thanks so much for engaging. I agree this is unusual; it’s new research with LLMs that I am carrying out for my PhD. What happens here is that I get the LLM to do error correction on some output it generated before, and there is a deterministic program that checks whether the output should be corrected or not. This program works on strings, so I need to detokenize and use some outside information to determine whether I should query the model again for a given example. Since compute_metrics runs at the end of the generation step, converts everything to strings, and can be passed the metadata it needs to correct the errors, I thought I could also use it to build dataset_dict. I wasn’t sure how to make sure that not all workers end up pre-processing the resulting dataset, so thanks so much for your input.
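
To check I understand the per-node variant you suggest, here is a minimal sketch of what I think you mean, assuming torch.distributed is initialised, each rank holds a disjoint shard_dict of raw examples, and the shard paths live on shared storage (preprocess_fn is a hypothetical transform):

    import torch.distributed as dist
    from datasets import Dataset, concatenate_datasets, load_from_disk

    rank, world_size = dist.get_rank(), dist.get_world_size()

    # each rank processes and saves only its own shard
    shard = Dataset.from_dict(shard_dict)
    shard = shard.map(preprocess_fn, batched=True)
    shard.save_to_disk(f"shards/rank_{rank}")

    # wait until every shard is on disk, then reload and merge the partial datasets
    dist.barrier()
    full_dataset = concatenate_datasets(
        [load_from_disk(f"shards/rank_{r}") for r in range(world_size)]
    )
    # full_dataset can now be passed to Trainer.predict for distributed evaluation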

@lhoestq I was hoping you could give feedback on some design I have put together for an even trickier generalisation of the above. First, a brief description of the problem.

Problem

Training a dialogue agent to communicate with one or more parties while accounting for feedback from a simulation environment. For simplicity, we assume there is a single party, which we call the “user”. Each party’s contribution to the conversation (e.g. “agent: Good morning, how can I help”) is called a turn. At turn k in a dialogue session, the predictions of the agent depend on predictions made in previous turns.

At the end of the interaction, the agent may interact with an environment to receive feedback, and some of its predictions will be revised for future training epochs.

Hence, we have to construct our eval/training data as we predict: at the start of the eval loop we have just the user’s first turn. I want to leverage the Trainer to do predictions in a distributed fashion as I train. Here’s how I conceptualise the solution.

  1. Use a dataset generator to create the datasets from the agent prediction history. In outline form this looks a bit like this:
from copy import deepcopy
from typing import Callable, Generator, Optional

import datasets
from datasets import Dataset

# DialogID, SessionTranscript, DraftSessionTranscript, FeatureName,
# Preprocessors and SimulationEnvironment are project-specific types
# defined elsewhere.


class DatasetGenerator:

    def __init__(
        self,
        environment: SimulationEnvironment,
        gold_transcripts: dict[DialogID, SessionTranscript],
        preprocessing_steps: Preprocessors,
        dialog_ids: set[DialogID],
        batch_size: int = 128,
    ):
        """
        environment
            Agent calls the environment after every turn to receive feedback
            about possible next steps.
        gold_transcripts:
            Ground truth conversation transcripts.
        preprocessing_steps:
            Transforms sequentially applied to convert transcripts into text-to-text-format.
        """
        self.environment = environment
        self.all_dialog_ids = dialog_ids
        self.dialogs_shard: set[DialogID] = set()
        # transcripts of the dialog sessions for which prediction is not yet complete
        self.ongoing_sessions: dict[DialogID, SessionTranscript] = {}
        # keep alive objects tracking agent predictions for evaluation purposes
        self.agent_predictions = {}
        self.gold_transcripts: dict[DialogID, SessionTranscript] = deepcopy(gold_transcripts)
        self.complete_sessions: dict[DialogID, SessionTranscript] = {}
        self.batch_size = batch_size
        self.current_dataset: Optional[datasets.Dataset] = None
        self.preprocessors = preprocessing_steps

    def initialise(self, global_rank: int):
        """initialises the ongoing session with the first conversation turn"""
        # distribute dialogues across workers to avoid sharing
        # state across processes
        self.distribute_dialogues(self.all_dialog_ids, global_rank)
        # initialise with transcript that contains first turn of every dialog ...
        self.ongoing_sessions = ...

    def finished(self) -> bool:
        """Finish generating datasets when all the dialogue sessions have been completed."""
        return not self.ongoing_sessions

    def distribute_dialogues(self, all_ids: set[DialogID], global_rank: int):
        """Shard dialogs across workers so that interaction with the environment
        can happen in parallel. The outputs are then synchronised across
        GPUs so that the next dataset can be generated."""
        ...

    def create_dataset(self) -> Dataset:
        """Create dataset to predict state updates and next agent action given its past predictions."""

        class dataset_generator:

            def __init__(self, prompt_generator: Callable):

                self.prompt_generator = prompt_generator
                self.generated = []

            def __call__(
                    self,
                    ongoing_sessions: dict[DialogID, DraftSessionTranscript],
                    dialogues: set[DialogID],
                    completed_sessions: set[DialogID]
            ) -> Generator[dict[FeatureName, str], None, None]:
                assert dialogues is not None
                for dial_id in dialogues:
                    if dial_id not in completed_sessions:
                        self.generated.append(dial_id)
                        example = {
                            'source': self.prompt_generator(ongoing_sessions[dial_id]), 'id': dial_id
                        }
                        yield example

        gen_kwargs = {
            "ongoing_sessions": self.ongoing_sessions,
            "dialogues": self.all_dialog_ids,
            "completed_sessions": set(self.complete_sessions.keys())
        }
        this_turn_dataset = Dataset.from_generator(
            dataset_generator(self.preprocessors["prompt_generator"]),
            gen_kwargs=gen_kwargs
        )
        return this_turn_dataset

    def prepare_for_prediction(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset."""
        num_proc = self.preprocessors["tokenization"].num_proc
        dataset = dataset.map(
            self.preprocessors["tokenization"],
            batched=True,
            num_proc=num_proc
        )
        self.current_dataset = dataset
        return dataset

    def collect_complete_sessions(self):
        """Remove complete dialogue sessions from ongoing sessions."""
        ...
    
    def update_sessions_with_agent_prediction(self, predictions: list[str]):
        """Add the prediction to its corresponding ongoing session."""
        for pred, record in zip(predictions, self.current_dataset):
            session = self.ongoing_sessions[record["id"]]
            # extend dialog session with the model prediction
            ...
        self.collect_complete_sessions()

    def update_sessions_with_system_actions(self, predictions: list[str]):
        """Add the system action predictions to the ongoing sessions"""

        for pred, record in zip(predictions, self.current_dataset):
            session = self.ongoing_sessions[record["id"]]
            ...

    def update_sessions_with_feedback(self) -> dict[DialogID, SessionTranscript]:
        """Interact with the environment to get feedback."""
        ...
        assert self.dialogs_shard
        assert not self.ongoing_sessions
        transcripts_with_feedback = {}
        for dial_id in self.dialogs_shard:
            session = self.complete_sessions[dial_id]
            ...

        return transcripts_with_feedback

    def generate_dataset_for_feedback_grounded_prediction(
            self, transcripts: dict[DialogID, SessionTranscript]
    ) -> Dataset:
        """Generate dataset for predicting agent's response to feedback."""
        self.ongoing_sessions = transcripts
        ...

I would then customise the eval and predict methods of the Trainer as follows:

import torch.distributed as dist
from transformers import Seq2SeqTrainer
from transformers.debug_utils import DebugOption
from transformers.utils import logging

logger = logging.get_logger(__name__)

# xm and met (used below) are the torch_xla debug modules, imported the same
# way as in the stock Trainer when running on TPU


class TrainerForDialogInference(Seq2SeqTrainer):

    def evaluate(
            self,
            eval_dataset: DatasetGenerator = None,
            ignore_keys: Optional[list[str]] = None,
            metric_key_prefix: str = "eval",
            **gen_kwargs,
    ) -> dict[str, float]:

        # omitting Seq2SeqTrainer kwargs handling here

        self._memory_tracker.start()
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        eval_dataset.initialise(self.accelerator.state.process_index)
        current_turn = 0
        while not eval_dataset.finished():
            self.compute_metrics.should_compute = False
            # create the same dataset in all processes in parallel
            dataset = eval_dataset.create_dataset()
            self.accelerator.wait_for_everyone()
            # tokenize the data once, replicas load the pre-processed data
            with self.accelerator.main_process_first():
                dataset = eval_dataset.prepare_for_prediction(dataset)
            eval_dataloader = self.get_eval_dataloader(dataset)
            # predict at nth turn in each dialogue session
            output = eval_loop(
                eval_dataloader,
                description=f"evaluation at turn {current_turn + 1}",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=False,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
            # extend the transcripts with predicted states and next turns
            eval_dataset.update_sessions_with_agent_prediction(
                output.predictions
            )
            logger.info(f"Predicted for turn: {current_turn + 1}")
            current_turn += 1
        # when all the replicas are done with the state update, interact with the simulation
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            logger.info("Executing programs")
        # each process interacts with the simulation for a subset of dialogues,
        # then all the transcripts are gathered on every process so that each
        # replica can rebuild the same full dataset
        session_with_feedback = eval_dataset.update_sessions_with_feedback()
        sessions_for_feedback_grounded_prediction = [None for _ in range(self.accelerator.num_processes)]
        dist.all_gather_object(
            sessions_for_feedback_grounded_prediction,
            session_with_feedback,
        )
        feedback_grounded_dataset = eval_dataset.generate_dataset_for_feedback_grounded_prediction(
            merge_disjoint_dicts(sessions_for_feedback_grounded_prediction)
        )
        # pre-process the dataset in the main process, replicas will load pre-proc dataset
        with self.accelerator.main_process_first():
            feedback_grounded_dataset = eval_dataset.prepare_for_prediction(feedback_grounded_dataset)
        action_dataloader = self.get_eval_dataloader(feedback_grounded_dataset)
        # predict system actions in parallel across GPUs
        output = eval_loop(
            action_dataloader,
            description="system action prediction",
            prediction_loss_only=False,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix
        )
        # update transcripts with predicted actions, ensuring that these transcripts
        # are comparable with ground truth
        eval_dataset.update_sessions_with_system_actions(output.predictions)

        # omitted code: ... metrics computation ...

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

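For reference, merge_disjoint_dicts is a small helper of mine (not shown above); a sketch of the intended behaviour:

    def merge_disjoint_dicts(dicts: list[dict]) -> dict:
        """Merge per-process transcript dicts whose key sets do not overlap."""
        merged: dict = {}
        for d in dicts:
            overlap = merged.keys() & d.keys()
            assert not overlap, f"Dialogue IDs duplicated across processes: {overlap}"
            merged.update(d)
        return merged
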
I was wondering whether this looks roughly correct, or whether you see any issues with regard to process management? In particular, I wasn’t quite sure if/how the replicas will see the tokenised datasets. Are there more efficient ways of doing what I set out to do? I do plan to open-source this work when it matures, so hopefully the community will benefit from it!