Hi Everyone,
I am fine-tuning an LLM on multiple GPUs with Accelerate. I also need to add a callback to the Trainer to sample and log predictions using Weights & Biases.
class WandbLLMSampleCallback(WandbCallback):
    def __init__(
        self,
        trainer,
        test_dataset,
        num_samples=10,
        max_new_tokens=256,
        log_model="checkpoint",
    ):
        super().__init__()
        ...

    def on_evaluate(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        # make sure the table is logged only once, from the main process
        if state.is_world_process_zero:
            records_table = self._samples_table(self.sample_dataset)
            self._wandb.log({"sample_predictions": records_table})
...
I observed that the on_evaluate method is called multiple times, equal to the number of processes.
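For context, here is a stdlib-only simulation of why I think this happens (FakeTrainerState and run_rank are hypothetical stand-ins, not the real transformers/accelerate API): under a distributed launch, every rank runs the full script, so every rank's trainer fires its own callbacks.

```python
from dataclasses import dataclass

# Hypothetical stand-in for transformers' TrainerState; in a real
# distributed launch, each of the N processes runs the script independently.
@dataclass
class FakeTrainerState:
    process_index: int

    @property
    def is_world_process_zero(self):
        # transformers sets this True only on the global rank-0 process
        return self.process_index == 0

def run_rank(process_index, call_log):
    state = FakeTrainerState(process_index)
    # every rank reaches on_evaluate, so it runs once per process
    call_log.append(("on_evaluate", process_index))
    if state.is_world_process_zero:
        call_log.append(("wandb.log", process_index))

calls = []
for rank in range(4):  # simulate a 4-GPU launch
    run_rank(rank, calls)

print(sum(1 for c in calls if c[0] == "on_evaluate"))  # once per process
print(sum(1 for c in calls if c[0] == "wandb.log"))    # only rank 0 logs
```

This matches what I see: on_evaluate fires on every rank, but the guarded log fires only once.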
My questions are:

In a distributed setup, does Accelerate create multiple Trainer objects, each with its own TrainerState, or one Trainer object with multiple TrainerStates?

When we do trainer.add_callback(WandbInputLoggerCallback(tokenizer)), does each trainer get a new instance of the callback? If we already need a check like if state.is_world_process_zero inside the callback, does it even make sense to create the redundant callback instances and add them to the other trainers? Or should we instead do:
if trainer.accelerator.is_main_process:
    trainer.add_callback(WandbInputLoggerCallback(tokenizer))
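To make the two options concrete, a minimal sketch with stdlib-only stand-ins (FakeTrainer, GuardedCallback, etc. are illustrative names, not the real Trainer/Accelerator API surface):

```python
class FakeTrainer:
    """Illustrative stand-in for one rank's Trainer object."""
    def __init__(self, rank):
        self.rank = rank
        self.is_main_process = rank == 0  # mimics accelerator.is_main_process
        self.callbacks = []

    def add_callback(self, cb):
        self.callbacks.append(cb)

    def evaluate(self, log_sink):
        for cb in self.callbacks:
            cb.on_evaluate(self, log_sink)

class GuardedCallback:
    """Option A: register on every rank, guard inside on_evaluate."""
    def on_evaluate(self, trainer, log_sink):
        if trainer.is_main_process:  # mimics state.is_world_process_zero
            log_sink.append(trainer.rank)

class UnguardedCallback:
    """Option B: no internal guard; registered on the main process only."""
    def on_evaluate(self, trainer, log_sink):
        log_sink.append(trainer.rank)

world_size = 4
logs_a, logs_b = [], []

# Option A: every rank registers the callback; the internal guard logs once.
for rank in range(world_size):
    t = FakeTrainer(rank)
    t.add_callback(GuardedCallback())
    t.evaluate(logs_a)

# Option B: only the main process registers the callback at all.
for rank in range(world_size):
    t = FakeTrainer(rank)
    if t.is_main_process:  # mimics `if trainer.accelerator.is_main_process:`
        t.add_callback(UnguardedCallback())
    t.evaluate(logs_b)

print(logs_a, logs_b)  # both options log exactly once, from rank 0
```

Both strategies end up logging once from rank 0; my question is which one is the intended/idiomatic approach.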
I'd appreciate your advice on this.
Thanks
Anindya