Inference with Multi-Step Reasoning

Problem statement

As models get more powerful at human-like reasoning, we would like to leverage them to solve tasks which are too complex for zero-shot prediction or in-context learning. Certain problems can be decomposed into simpler problems that the model can solve without finetuning, and the outputs generated when solving these simpler tasks can be used as cues for generating further prompts to complete the final task. This is a bit like the model prompting itself.
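To make the idea concrete, here is a minimal plain-Python sketch of a two-step chain. The `generate` stub stands in for a real model call, and the task names, templates and helper functions are all hypothetical. The point is only that the second prompt cannot be built until the first output exists.

```python
from typing import Callable, Dict, List, Tuple

# Each step is (task_name, prompt_template). Templates may reference the
# outputs of earlier steps by task name, e.g. "{extract}".
Step = Tuple[str, str]


def run_chain(question: str, steps: List[Step],
              generate: Callable[[str], str]) -> Dict[str, str]:
    """Run steps in order, feeding earlier outputs into later prompts."""
    outputs: Dict[str, str] = {}
    for task_name, template in steps:
        # The prompt is only known once all earlier outputs are available.
        prompt = template.format(question=question, **outputs)
        outputs[task_name] = generate(prompt)
    return outputs


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for a model: wraps the prompt in brackets."""
    return "[" + prompt + "]"


steps = [
    ("extract", "List the entities in: {question}"),
    ("answer", "Question: {question}\nEntities: {extract}\nAnswer:"),
]
result = run_chain("who wrote hamlet?", steps, fake_generate)
print(result["answer"])
```

The output of the first step appears verbatim inside the prompt of the second, which is the "model prompting itself" pattern in its simplest form.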

What does this mean in practice?

In practice, we can use the standard tools to process the data for the initial task(s), whose prompts do not depend on previous task outputs. However, we do not know the prompts for the remaining tasks in advance, because they depend on the labels generated at previous steps. Therefore, unlike in all the helpful examples in the docs, we can only build the prompts and tokenize them after the initial tasks have completed. Crucially, we would like to leverage the Trainer and its ability to run distributed inference, and also to process the task data only once while ensuring it is available to all replicas. A rough sketch of my proposed implementation is:

from typing import Dict, List, Optional, Union

from datasets import Dataset
from transformers import Seq2SeqTrainer
from transformers.trainer_utils import PredictionOutput


class TrainerForMultiStepReasoning(Seq2SeqTrainer):

    def multi_step_reasoning_predict(
            self,
            test_dataset: Union[Dict[str, Dataset], Dataset],
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "test",
    ) -> Union[PredictionOutput, Dict[str, PredictionOutput]]:
        """Extends `predict` to make predictions on a sequence of tasks
        where later tasks' prompts depend on earlier tasks' outputs.
        """
        if isinstance(test_dataset, dict):
            task_predictions = {}
            # note that compute_metrics becomes stateful and is task aware.
            # compatibility maintained by implementing __call__(predictions)
            for task in self.compute_metrics.task_iter:
                task_name = ...  # elided in the original sketch
                task_processed = task.processed
                if task_processed:
                    # task has no dependency: its data is ready, so this is
                    # just a standard predict call
                    task_dataset = test_dataset[task_name]
                else:
                    # block replicas and execute processing in main process;
                    # preprocessor.preprocess runs under the hood with
                    # load_from_cache_file=True (`preprocessor` is defined elsewhere)
                    with self.args.main_process_first(
                            desc=f"Processing data for task {task_name}"
                    ):
                        task_dataset = preprocessor.preprocess(...)  # arguments elided
                task_predictions[task_name] = self.predict(
                    task_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}-{task_name}",
                )
            return task_predictions

        # a single dataset needs no multi-step handling
        return self.predict(
            test_dataset,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
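The loop in the sketch relies on `task_iter` yielding tasks in an order where every task comes after the tasks it depends on. One way to derive that order, and to see which tasks can be preprocessed upfront, is to treat the dependencies as a graph. The sketch below uses the standard-library `graphlib.TopologicalSorter` (Python 3.9+); the task names, the `deps` map and `plan_tasks` are hypothetical.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set, Tuple


def plan_tasks(deps: Dict[str, Set[str]]) -> Tuple[List[str], List[str]]:
    """Return (tasks preprocessable upfront, a valid execution order).

    `deps` maps every task to the set of tasks whose outputs its prompt
    uses (assumption: every task appears as a key).
    """
    upfront = sorted(task for task, parents in deps.items() if not parents)
    # TopologicalSorter expects node -> predecessors, which is what deps is.
    order = list(TopologicalSorter(deps).static_order())
    return upfront, order


deps = {
    "extract": set(),
    "summarise": set(),
    "answer": {"extract", "summarise"},
    "verify": {"answer"},
}
upfront, order = plan_tasks(deps)
print(upfront)  # → ['extract', 'summarise']
print(order)
```

`static_order` also raises `graphlib.CycleError` if the dependencies are circular, which is a cheap sanity check before launching a long inference run.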

Things to note:

  • compute_metrics is stateful: it manages the task ordering and stores the information needed to preprocess later tasks. This is a bit hacky, but I want to change the Trainer as little as possible, and this seems like a sensible option.
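A minimal sketch of what such a stateful, task-aware `compute_metrics` object might look like. It assumes the Trainer calls it once per `predict` call with an object exposing `predictions` and `label_ids` (as `transformers.EvalPrediction` does), and that predictions have already been decoded to strings. As far as I know, the Trainer only invokes `compute_metrics` when labels are available, so tasks without labels may need another hook to capture their generations. `MultiStepMetrics`, `Task`, `task_outputs` and the exact-match metric are all hypothetical, and `FakeEvalPrediction` is a stand-in that keeps the example self-contained.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence


@dataclass
class FakeEvalPrediction:
    """Stand-in for transformers.EvalPrediction (predictions already decoded)."""
    predictions: Sequence[str]
    label_ids: Sequence[str]


@dataclass
class Task:
    """Hypothetical task record: its name and whether its prompts are ready."""
    name: str
    processed: bool


class MultiStepMetrics:
    """Hypothetical stateful compute_metrics that tracks the current task."""

    def __init__(self, tasks: List[Task]):
        self.tasks = tasks
        self.task_outputs: Dict[str, List[str]] = {}
        self._current: Optional[Task] = None

    @property
    def task_iter(self) -> Iterator[Task]:
        # Yield tasks in order, remembering which one is being predicted so
        # that __call__ knows where to store its outputs.
        for task in self.tasks:
            self._current = task
            yield task

    def __call__(self, eval_pred) -> Dict[str, float]:
        if self._current is None:
            raise RuntimeError("iterate over task_iter before computing metrics")
        preds = list(eval_pred.predictions)
        labels = list(eval_pred.label_ids)
        # Keep the generations so later tasks can build their prompts.
        self.task_outputs[self._current.name] = preds
        matches = sum(p == l for p, l in zip(preds, labels))
        return {"exact_match": matches / max(len(labels), 1)}


metrics = MultiStepMetrics([Task("extract", processed=True),
                            Task("answer", processed=False)])
for task in metrics.task_iter:
    # In the real trainer, this call would happen inside self.predict(...).
    scores = metrics(FakeEvalPrediction([f"{task.name}-1", "x"],
                                        [f"{task.name}-1", "y"]))
    print(task.name, scores)
```

Because `task_iter` is a generator property, the trainer's `for task in self.compute_metrics.task_iter` loop and the metric calls made inside `predict` share the same `_current` pointer, which is what makes the object "task aware" without changing the `compute_metrics(predictions)` call signature.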

My questions

  1. Are there any pitfalls with using the main_process_first context manager the way I am planning? My understanding is that it will block all replicas until the main process has finished preprocessing, after which they will simply load the processed dataset from the cache and proceed with inference. @sgugger, does my thinking/implementation make sense?
  2. Should the same machinery work for evaluation as well? The only difference would be that the implementation would likely live inside _maybe_log_save_evaluate.
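To reason about question 1, it may help to simulate the barrier pattern with threads standing in for replicas. My understanding is that `TrainingArguments.main_process_first` makes non-main processes wait at a barrier before entering the block, while the main process reaches the barrier only after leaving it. The sketch below reproduces that ordering with `threading.Barrier` and uses a plain dict as a hypothetical stand-in for the on-disk `datasets` cache. It is a simulation of the pattern, not the library's code, and it treats rank 0 as the only main process (the real context manager defaults to one main process per node).

```python
import threading
from contextlib import contextmanager
from typing import Dict, List

WORLD_SIZE = 4
barrier = threading.Barrier(WORLD_SIZE)
cache: Dict[str, List[str]] = {}  # stand-in for the datasets cache on disk
cache_lock = threading.Lock()
calls: List[int] = []  # ranks that actually ran the expensive preprocessing


@contextmanager
def main_process_first(rank: int):
    """Replicas wait until rank 0 has run the body, then run it themselves."""
    is_main = rank == 0
    if not is_main:
        barrier.wait()  # block until the main process leaves the block
    try:
        yield
    finally:
        if is_main:
            barrier.wait()  # release the waiting replicas


def preprocess(task_name: str, rank: int) -> List[str]:
    """Hypothetical preprocessing: compute once, later callers hit the cache."""
    with cache_lock:
        if task_name not in cache:
            calls.append(rank)
            cache[task_name] = [f"prompt for {task_name} #{i}" for i in range(3)]
        return cache[task_name]


results: Dict[int, List[str]] = {}


def replica(rank: int) -> None:
    with main_process_first(rank):
        results[rank] = preprocess("answer", rank)


threads = [threading.Thread(target=replica, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(calls)  # → [0]
```

One consequence of this design: every replica must reach the context manager, because the barrier counts all participants. A replica that skips it (for example, through a rank-dependent branch) would leave the others waiting indefinitely.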

@lhoestq, given our super productive discussion in the other thread, is my on-the-fly processing correctly sketched here too? Note that, unlike in the other post, here we only have to fill in part of the prompt based on a previous generation; we don't build the data on the fly. preprocessor.preprocess looks at the previous generations (stored in self.compute_metrics.task_outputs), completes the prompts and tokenizes the data.
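A minimal sketch of the prompt-completion step, assuming the previous task's generations are stored as a list (for example, one entry of `task_outputs`) that lines up one-to-one with the examples of the dependent task. The template, the `question`/`prompt` field names and `make_prompt_completer` are hypothetical, and tokenization is left out. The returned function has the shape of a batched map function that also receives example indices, so it could be adapted for `Dataset.map(..., batched=True, with_indices=True)`.

```python
from typing import Callable, Dict, List


def make_prompt_completer(
    previous_outputs: List[str], template: str
) -> Callable[[Dict[str, List[str]], List[int]], Dict[str, List[str]]]:
    """Return a batched function that fills `template` with the generation
    produced for the same example at the previous step."""

    def complete_prompts(batch: Dict[str, List[str]],
                         indices: List[int]) -> Dict[str, List[str]]:
        prompts = [
            template.format(question=q, previous=previous_outputs[i])
            for q, i in zip(batch["question"], indices)
        ]
        return {"prompt": prompts}

    return complete_prompts


# Hypothetical data: two examples and the generations from the previous task.
batch = {"question": ["Who wrote Hamlet?", "Where is Paris?"]}
previous = ["Shakespeare, Hamlet", "Paris"]
completer = make_prompt_completer(
    previous, "Entities: {previous}\nQuestion: {question}\nAnswer:"
)
out = completer(batch, [0, 1])
print(out["prompt"][0])
```

Indexing generations by position assumes the dependent dataset has not been shuffled or filtered since the previous step; if it has, keying the stored generations by an example id would be safer.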