It seems that calling `wandb.init` only on the main worker worked for me: it stopped all 8 workers from each logging to wandb separately. Code:
# Logging backend toggle: uncomment the "none" line to disable wandb entirely.
# report_to = "none"
report_to = "wandb"
# Only the main process calls wandb.init, so the distributed workers do not
# each start their own wandb run.
# NOTE(review): is_main_process_using_local_rank must already be defined or
# imported when this line runs; in this snippet its `def` appears further down,
# which would raise NameError if executed top-to-bottom as a single script.
if report_to != 'none' and is_main_process_using_local_rank():
    wandb.init(
        project="proj",
        entity="me",
        name='run_name_clean_wandb_dist',
        group='group_expt_name',
    )
def is_main_process_using_local_rank() -> bool:
    """
    Return True if this process should act as the "main" process, based on its local rank.

    The local rank is obtained from ``get_local_rank()`` (defined elsewhere).
    The convention used here is:

    * ``-1``: running serially (no distributed process group).
    * ``0``: the first worker when running in parallel.

    Both cases count as the main process. Code guarded by this check, such as
    ``wandb.init``, is therefore skipped by the other parallel workers. Values
    observed via print statements when running in parallel include 0 and 1.

    NOTE(review): if ``get_local_rank()`` returns a per-node rank (the usual
    meaning of "local rank"), a multi-node job gets one True per node. Confirm
    this, and compare against the global rank
    (``torch.distributed.get_rank() == 0``) if exactly one process overall is
    required.

    For reference, process setup elsewhere (l2l-style) calls
    ``init_process_group_l2l(...)``. It then sets ``rank`` to
    ``torch.distributed.get_rank()`` when running in parallel, and to -1
    otherwise.

    Returns:
        True when the local rank is -1 (serial) or 0; False otherwise.
    """
    local_rank: int = get_local_rank()
    # -1 means serial; 0 is the first worker when running in parallel.
    return local_rank in (-1, 0)