I am trying to train a model in Google Colab using TPU distributed training. The dataset is my own, and two functions, get_train_dataloader and get_eval_dataloader, are overridden to use a batch sampler. I spawn the processes with the following code:
xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')
This works flawlessly until I enable wandb for reporting. The training gets stuck at the beginning, showing 2 steps performed out of several. I noticed that only 7 tqdm progress bars out of 8 are displayed when this happens. It seems the Trainer is waiting forever for the last process. I also tried disabling WANDB_WATCH with the following code, but it had no effect unless I disabled wandb entirely:
os.environ["WANDB_WATCH"] = "false"
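For completeness, the only thing that unblocked training was disabling wandb entirely. A minimal sketch of how I did that, set before calling xmp.spawn so every forked worker inherits the variables (WANDB_DISABLED is the switch read by transformers' Trainer, WANDB_MODE the one read by the wandb client):

```python
import os

# Disable wandb completely; must be set BEFORE xmp.spawn so the
# forked TPU workers inherit these environment variables.
os.environ["WANDB_DISABLED"] = "true"   # transformers: skip the wandb callback
os.environ["WANDB_MODE"] = "disabled"   # wandb client: turn all logging into no-ops

print(os.environ["WANDB_DISABLED"], os.environ["WANDB_MODE"])  # → true disabled
```

With these set, training runs to completion, which is why I suspect the hang is in wandb's interaction with one of the spawned processes rather than in my dataloaders.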