I am trying to train a model in Google Colab using TPU distributed training. The dataset is my own, and a few functions (get_train_dataloader and get_eval_dataloader) are overridden to use a batch sampler. I spawn the processes with the following code:
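(The spawn follows the usual torch_xla multiprocessing pattern; the entry point and flags below are placeholders rather than my exact code:)

```python
import torch_xla.distributed.xla_multiprocessing as xmp

flags = {"num_epochs": 1}  # placeholder hyperparameters

def _mp_fn(index, flags):
    # One copy of this runs on each of the 8 TPU cores; this is where
    # the Trainer gets built and trainer.train() is called.
    print(f"worker {index} started with {flags}")

# 'fork' is required in Colab notebooks; nprocs=8 matches the 8 TPU cores.
xmp.spawn(_mp_fn, args=(flags,), nprocs=8, start_method='fork')
```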
This works flawlessly until I enable wandb for reporting. Training then gets stuck near the beginning, showing 2 steps completed out of several, and I noticed that only 7 of the 8 tqdm progress bars are displayed when this happens. It seems like the Trainer is waiting forever for the last process. I also tried disabling WANDB_WATCH with the following code, but it had no effect unless I disabled wandb entirely.
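(Roughly this kind of toggle, via the WANDB_WATCH / WANDB_DISABLED environment variables that the transformers wandb integration reads; the values here are illustrative:)

```python
import os

# Tell the transformers wandb callback not to call wandb.watch on the model.
os.environ["WANDB_WATCH"] = "false"

# Only turning wandb off completely avoided the hang:
# os.environ["WANDB_DISABLED"] = "true"
```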
It seems that only doing the wandb init on the main worker worked for me; it stopped all 8 workers from logging to wandb. Code:
report_to = "none"
report_to = "wandb"
if report_to != 'none' and is_main_process_using_local_rank():
wandb.init(project="proj", entity="me", name='run_name_clean_wandb_dist', group='group_expt_name')
```python
def is_main_process_using_local_rank() -> bool:
    """
    Determines whether this is the main process, using the local rank.

    Based on print statements:
        local_rank=0
        local_rank=1

    Other ref:
        # - set up processes a la l2l
        local_rank: int = get_local_rank()
        print(f'{local_rank=}')
        init_process_group_l2l(args, local_rank=local_rank, world_size=args.world_size, init_method=args.init_method)
        rank: int = torch.distributed.get_rank() if is_running_parallel(local_rank) else -1
        args.rank = rank  # have each process save the rank
        set_devices_and_seed_ala_l2l(args)  # args.device = rank or .device
        print(f'setup process done for rank={args.rank}')
    """
    local_rank: int = get_local_rank()  # user helper: local rank of this process, -1 when not distributed
    return local_rank == -1 or local_rank == 0  # -1 means serial (no distribution), 0 is the main process when parallel
```
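get_local_rank is just a small helper; a minimal version, assuming the launcher sets the LOCAL_RANK environment variable, could be:

```python
import os

def get_local_rank() -> int:
    # LOCAL_RANK is set by torch.distributed launchers; default to -1 (serial run).
    return int(os.environ.get("LOCAL_RANK", -1))
```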