Do Trainer and Callback get created multiple times in a distributed setup?

Hi Everyone,

I am fine-tuning an LLM on multiple GPUs with Accelerate. I also need to add a callback to the Trainer that samples predictions and logs them to Weights & Biases.

from transformers.integrations import WandbCallback

class WandbLLMSampleCallback(WandbCallback):
    def __init__(
        self,
        trainer,
        test_dataset,
        num_samples=10,
        max_new_tokens=256,
        log_model="checkpoint",
    ):
        super().__init__()
        ...
    def on_evaluate(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        # log the sample predictions only once, from the main process
        if state.is_world_process_zero:
            records_table = self._samples_table(self.sample_dataset)
            self._wandb.log({"sample_predictions": records_table})
...

I observed that the on_evaluate method is called once per process, so the number of calls equals the number of processes.

My questions are:
In a distributed setup, does Accelerate create multiple Trainer objects, each with its own TrainerState, or one Trainer object with multiple TrainerStates?

When we do trainer.add_callback(WandbInputLoggerCallback(tokenizer)), does each Trainer get a new instance of the callback? If we already need a check like if state.is_world_process_zero inside the callback, does it even make sense to create the redundant callback instances and add them to the other Trainers? Or should we do

if trainer.accelerator.is_main_process:
    trainer.add_callback(WandbInputLoggerCallback(tokenizer))
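
For reference, here is a minimal sketch of the guard pattern in question. It assumes that every launched process runs the full script and therefore builds its own trainer, state, and callback. The names FakeState, SampleLoggerCallback, and run_rank are hypothetical stand-ins for illustration only, not transformers or accelerate APIs, and the processes are simulated sequentially rather than actually launched.

```python
from dataclasses import dataclass


@dataclass
class FakeState:
    """Hypothetical stand-in for a per-process trainer state."""
    is_world_process_zero: bool


class SampleLoggerCallback:
    """Hypothetical callback that mirrors the on_evaluate guard from the post."""

    def __init__(self):
        self.logged = []

    def on_evaluate(self, args, state, control, **kwargs):
        # Every process reaches this hook; only rank 0 records a log entry.
        if state.is_world_process_zero:
            self.logged.append("sample_predictions")


def run_rank(rank):
    """Simulate one process: it builds its own state and callback instance."""
    state = FakeState(is_world_process_zero=(rank == 0))
    callback = SampleLoggerCallback()
    callback.on_evaluate(args=None, state=state, control=None)
    return callback


# Simulate 4 processes one after another (a real launch runs them in parallel).
callbacks = [run_rank(rank) for rank in range(4)]
calls_that_logged = sum(len(cb.logged) for cb in callbacks)
print(calls_that_logged)
```

In this simulation the hook runs four times on four separate callback instances, but only the rank 0 call logs anything, which is consistent with on_evaluate firing once per process.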

Appreciate your advice on this.

Thanks
Anindya


Could you share the code for your callback? I am trying to do the same thing, but my code gets stuck inside the callback when I use Accelerate for generation.
Thanks
