@lhoestq I was hoping you could give some feedback on a design I have put together for an even trickier generalisation of the above. First, a brief description of the problem.
**Problem**

Training a dialogue agent to communicate with one or more parties while accounting for feedback from a simulation environment. For simplicity, we assume there is a single party, which we call the "user". Each party's contribution to the conversation (e.g. "agent: Good morning, how can I help?") is called a *turn*. At turn *k* in a dialogue session, the agent's predictions depend on predictions made in previous turns. At the end of the interaction, the agent may interact with an environment to receive feedback, and some of its predictions would then be revised for future training epochs.
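For reference, the outlines below rely on a few project-specific types. Roughly, they look like this; the exact fields are simplified placeholders rather than the real schema:

```python
from dataclasses import dataclass, field

DialogID = str  # unique identifier of a dialogue session

@dataclass
class Turn:
    speaker: str  # e.g. "agent" or "user"
    text: str

@dataclass
class SessionTranscript:
    dialog_id: DialogID
    turns: list[Turn] = field(default_factory=list)
    feedback: list[str] = field(default_factory=list)  # environment feedback, if any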
Hence, we have to construct our eval/training data as we predict: at the start of the eval loop we have just the user's first turn. I want to leverage the `Trainer` to run predictions in a distributed fashion as I train. Here's how I conceptualise the solution.
- Use a dataset generator to create the datasets from the agent's prediction history. In outline form, this looks a bit like this:
```python
from copy import deepcopy
from typing import Callable, Generator, Optional

import datasets
from datasets import Dataset


class DatasetGenerator:

    def __init__(
        self,
        environment: SimulationEnvironment,
        gold_transcripts: dict[DialogID, SessionTranscript],
        preprocessing_steps: Preprocessors,
        dialog_ids: set[DialogID],
        batch_size: int = 128,
    ):
        """
        environment:
            Agent calls the environment after every turn to receive feedback
            about possible next steps.
        gold_transcripts:
            Ground truth conversation transcripts.
        preprocessing_steps:
            Transforms sequentially applied to convert transcripts into
            text-to-text format.
        """
        self.environment = environment
        self.all_dialog_ids = dialog_ids
        self.dialogs_shard: set[DialogID] = set()
        # transcripts of the dialog sessions for which prediction is not yet complete
        self.ongoing_sessions: dict[DialogID, SessionTranscript] = {}
        # keep-alive objects tracking agent predictions for evaluation purposes
        self.agent_predictions = {}
        self.gold_transcripts: dict[DialogID, SessionTranscript] = deepcopy(gold_transcripts)
        self.complete_sessions: dict[DialogID, SessionTranscript] = {}
        self.batch_size = batch_size
        self.current_dataset: Optional[datasets.Dataset] = None
        self.preprocessors = preprocessing_steps

    def initialise(self, global_rank: int):
        """Initialise the ongoing sessions with the first conversation turn."""
        # distribute dialogues across workers to avoid sharing
        # state across processes
        self.distribute_dialogues(self.all_dialog_ids, global_rank)
        # initialise with transcript that contains first turn of every dialog ...
        self.ongoing_sessions = ...

    def finished(self) -> bool:
        """Finish generating datasets when all the dialogue sessions have been completed."""
        return not self.ongoing_sessions

    def distribute_dialogues(self, all_ids: set[DialogID], global_rank: int):
        """Shard dialogs across workers so that interaction with the environment
        can happen in parallel. The outputs are then synchronised across
        GPUs so that the next dataset can be generated."""
        ...

    def create_dataset(self) -> Dataset:
        """Create dataset to predict state updates and next agent action given its past predictions."""

        class dataset_generator:

            def __init__(self, prompt_generator: Callable):
                self.prompt_generator = prompt_generator
                self.generated = []

            def __call__(
                self,
                ongoing_sessions: dict[DialogID, DraftSessionTranscript],
                dialogues: set[DialogID],
                completed_sessions: set[DialogID],
            ) -> Generator[dict[FeatureName, str], None, None]:
                assert dialogues is not None
                for dial_id in dialogues:
                    if dial_id not in completed_sessions:
                        self.generated.append(dial_id)
                        example = {
                            "source": self.prompt_generator(ongoing_sessions[dial_id]),
                            "id": dial_id,
                        }
                        yield example

        gen_kwargs = {
            "ongoing_sessions": self.ongoing_sessions,
            "dialogues": self.all_dialog_ids,
            "completed_sessions": set(self.complete_sessions.keys()),
        }
        this_turn_dataset = Dataset.from_generator(
            dataset_generator(self.preprocessors["prompt_generator"]),
            gen_kwargs=gen_kwargs,
        )
        return this_turn_dataset

    def prepare_for_prediction(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset."""
        num_proc = self.preprocessors["tokenization"].num_proc
        # NB: Dataset.map is not in-place, so the result must be reassigned
        dataset = dataset.map(
            self.preprocessors["tokenization"],
            batched=True,
            num_proc=num_proc,
        )
        self.current_dataset = dataset
        return dataset

    def collect_complete_sessions(self):
        """Remove complete dialogue sessions from ongoing sessions."""
        ...

    def update_sessions_with_agent_prediction(self, predictions: list[str]):
        """Add the prediction to its corresponding ongoing session."""
        for pred, record in zip(predictions, self.current_dataset):
            session = self.ongoing_sessions[record["id"]]
            # extend dialog session with the model prediction
            ...
        self.collect_complete_sessions()

    def update_sessions_with_system_actions(self, predictions: list[str]):
        """Add the system action predictions to the ongoing sessions."""
        for pred, record in zip(predictions, self.current_dataset):
            session = self.ongoing_sessions[record["id"]]
            ...

    def update_sessions_with_feedback(self) -> dict[DialogID, SessionTranscript]:
        """Interact with the environment to get feedback."""
        ...
        assert self.dialogs_shard
        assert not self.ongoing_sessions
        transcripts_with_feedback = {}
        for dial_id in self.dialogs_shard:
            session = self.complete_sessions[dial_id]
            ...
        return transcripts_with_feedback

    def generate_dataset_for_feedback_grounded_prediction(
        self, transcripts: dict[DialogID, SessionTranscript]
    ) -> Dataset:
        """Generate dataset for predicting the agent's response to feedback."""
        self.ongoing_sessions = transcripts
        ...
```
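To make the intended control flow concrete, here is a minimal single-process sketch of how the generator would be driven (distribution aside; `model_predict`, `env`, `gold`, `preprocessors` and `dialog_ids` are hypothetical stand-ins, not part of the design above):

```python
# Hypothetical single-process driver; `model_predict` stands in for the
# Trainer's prediction loop.
generator = DatasetGenerator(env, gold, preprocessors, dialog_ids)
generator.initialise(global_rank=0)

while not generator.finished():
    dataset = generator.create_dataset()                 # prompts for the current turn
    dataset = generator.prepare_for_prediction(dataset)  # tokenise
    predictions = model_predict(dataset)                 # one prediction per ongoing dialogue
    generator.update_sessions_with_agent_prediction(predictions)

# all sessions complete: ground the final predictions in environment feedback
transcripts = generator.update_sessions_with_feedback()
feedback_dataset = generator.generate_dataset_for_feedback_grounded_prediction(transcripts)
feedback_dataset = generator.prepare_for_prediction(feedback_dataset)
generator.update_sessions_with_system_actions(model_predict(feedback_dataset))
```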
I would then customise the `evaluate` and `predict` methods of the `Trainer` as follows:
```python
import torch.distributed as dist
from transformers import Seq2SeqTrainer
from transformers.debug_utils import DebugOption


class TrainerForDialogInference(Seq2SeqTrainer):

    def evaluate(
        self,
        eval_dataset: DatasetGenerator = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> dict[str, float]:
        # omitting Seq2SeqTrainer kwargs handling here
        self._memory_tracker.start()
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        eval_dataset.initialise(self.accelerator.state.process_index)
        current_turn = 0
        while not eval_dataset.finished():
            self.compute_metrics.should_compute = False
            # create the same dataset in all processes in parallel
            dataset = eval_dataset.create_dataset()
            self.accelerator.wait_for_everyone()
            # tokenize the data once, replicas load the pre-processed data
            with self.accelerator.main_process_first():
                dataset = eval_dataset.prepare_for_prediction(dataset)
            eval_dataloader = self.get_eval_dataloader(dataset)
            # predict at the nth turn in each dialogue session
            output = eval_loop(
                eval_dataloader,
                description=f"evaluation at turn {current_turn + 1}",
                # No point gathering the predictions if there are no metrics,
                # otherwise we defer to self.args.prediction_loss_only
                prediction_loss_only=False,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
            # extend the transcripts with predicted states and next turns
            eval_dataset.update_sessions_with_agent_prediction(output.predictions)
            logger.info(f"Predicted for turn: {current_turn + 1}")
            current_turn += 1
        # when all the replicas are done with the state update, interact with the simulation
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            logger.info("Executing programs")
        # each process interacts with the simulation for a subset of dialogues
        # - then we collect the objects from all processes on the main process
        session_with_feedback = eval_dataset.update_sessions_with_feedback()
        sessions_for_feedback_grounded_prediction = [None for _ in range(self.accelerator.num_processes)]
        dist.gather_object(
            session_with_feedback,
            sessions_for_feedback_grounded_prediction if self.accelerator.process_index == 0 else None,
            dst=0,
        )
        feedback_grounded_dataset = eval_dataset.generate_dataset_for_feedback_grounded_prediction(
            merge_disjoint_dicts(sessions_for_feedback_grounded_prediction)
        )
        # pre-process the dataset in the main process, replicas will load the pre-processed dataset
        with self.accelerator.main_process_first():
            feedback_grounded_dataset = eval_dataset.prepare_for_prediction(feedback_grounded_dataset)
        action_dataloader = self.get_eval_dataloader(feedback_grounded_dataset)
        # predict system actions in parallel across GPUs
        output = eval_loop(
            action_dataloader,
            description="system action prediction",
            prediction_loss_only=False,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        # update transcripts with predicted actions, ensuring that these transcripts
        # are comparable with ground truth
        eval_dataset.update_sessions_with_system_actions(output.predictions)
        # omitted code: ... metrics computation ...
        self.log(output.metrics)
        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)
        return output.metrics
```
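For completeness, `merge_disjoint_dicts` is a small helper I haven't shown above. A minimal sketch, assuming each dialogue is sharded to exactly one process (so keys never overlap) and that non-destination ranks are left with a list of `None`s after `gather_object`:

```python
from typing import Optional, TypeVar

K = TypeVar("K")
V = TypeVar("V")

def merge_disjoint_dicts(shards: list[Optional[dict[K, V]]]) -> dict[K, V]:
    """Merge per-process shards into a single mapping with disjoint keys."""
    merged: dict[K, V] = {}
    for shard in shards:
        if shard is None:
            # on non-destination ranks gather_object leaves the placeholders untouched
            continue
        overlap = merged.keys() & shard.keys()
        assert not overlap, f"dialogues sharded to multiple processes: {overlap}"
        merged.update(shard)
    return merged
```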
I was wondering whether this looks roughly correct, or whether you see any issues with regard to process management? In particular, I wasn't quite sure if/how the replicas will see the tokenised datasets. Are there more efficient ways of doing what I set out to do? I plan to open-source this work when it matures, so hopefully the community will benefit from it!