Is wandb in Trainer configured for distributed training?

I am trying to train a model in Google Colab using TPU distributed training. The dataset is my own, and a few functions, get_train_dataloader and get_eval_dataloader, are overridden to use a batch sampler. I spawn the processes using the following code:

xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')
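For reference, the dataloader overrides look roughly like this (a trimmed sketch; my_batch_sampler stands in for my dataset-specific batch sampler):

from torch.utils.data import DataLoader
from transformers import Trainer

class MyTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        # batch_sampler replaces the default batch_size/shuffle logic
        return DataLoader(
            self.train_dataset,
            batch_sampler=my_batch_sampler,
            collate_fn=self.data_collator,
        )

get_eval_dataloader is overridden analogously.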

This works flawlessly until I enable wandb for reporting. Training then gets stuck at the beginning, showing 2 steps performed out of several. I noticed that only 7 tqdm progress bars out of 8 are displayed when this happens, so it seems like the Trainer is waiting for the last process forever. I also tried disabling WANDB_WATCH with the following code, but it had no effect unless I disabled wandb entirely.

os.environ["WANDB_WATCH"] = "false"
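For completeness, disabling wandb entirely, which did let training proceed, amounts to something like this (WANDB_DISABLED is the environment variable the Trainer integration checks, as far as I know):

import os

os.environ["WANDB_DISABLED"] = "true"  # turns off the wandb integration
# equivalently: pass report_to="none" in TrainingArguments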


I am having an issue where all 8 processes are logging to wandb. Do you have that issue too?

related: Logging & Experiment tracking with W&B - #73 by brando

It seems that calling wandb.init only from the main worker worked for me; it stopped all 8 workers from logging to wandb. Code:

report_to = "none"
report_to = "wandb"
if report_to != 'none' and is_main_process_using_local_rank():
        wandb.init(project="proj", entity="me", name='run_name_clean_wandb_dist', group='group_expt_name')

def is_main_process_using_local_rank() -> bool:
    """
    Determines whether this is the main process, based on the local rank.

    based on print statements:
        local_rank=0
        local_rank=1

    other ref:
        # - set up processes a la l2l
        local_rank: int = get_local_rank()
        print(f'{local_rank=}')
        init_process_group_l2l(args, local_rank=local_rank, world_size=args.world_size, init_method=args.init_method)
        rank: int = torch.distributed.get_rank() if is_running_parallel(local_rank) else -1
        args.rank = rank  # have each process save the rank
        set_devices_and_seed_ala_l2l(args)  # args.device = rank or .device
        print(f'setup process done for rank={args.rank}')
    """
    local_rank: int = get_local_rank()
    local_rank: int = get_local_rank()
    return local_rank == -1 or local_rank == 0  # -1 means serial execution; 0 is the main process when running in parallel

@brando looks like you got it sorted. The alternative is to use the group argument in wandb.init; you can read more here:
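Roughly, that grouping pattern looks like this (a sketch; the project and group names are placeholders, and local_rank is assumed to be available in each process):

import wandb

# every process calls init with the same group name; the W&B UI
# then collects the per-process runs under that one group
wandb.init(
    project="proj",
    group="tpu-spawn-experiment",   # shared across all processes
    name=f"worker-{local_rank}",    # per-process run name
)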