Wav2Vec2.0 fine-tuning with distributed training

Hi, I'm trying to fine-tune Wav2Vec2.0 with distributed training, following the fine-tune-wav2vec2 tutorial.

For distributed training, the only changes I made were adding sharded_ddp='simple' to the training arguments and launching train.py with torch.distributed.launch.

But when I start training with the launch command below, training does not proceed: it hangs at step 0, and GPU usage stays at 100%.

Is there any way to fine-tune Wav2Vec2.0 with multi-GPU training?
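To narrow down where each process is stuck at step 0, one option is Python's built-in faulthandler module, which can print the stack of every thread if the process is still running after a timeout. This is a minimal sketch, not part of the original script: the helper names enable_hang_dump and disable_hang_dump and the 300-second default are assumptions chosen for illustration.

```python
import faulthandler
import sys


def enable_hang_dump(timeout_s=300, file=sys.stderr):
    """Dump all Python thread stacks every timeout_s seconds while the process runs.

    Hypothetical helper: call it once near the top of train.py in every process.
    """
    # Also dump stacks on fatal signals such as SIGSEGV or SIGABRT.
    faulthandler.enable(file=file)
    # Starts a watchdog thread; with repeat=True the dump recurs every timeout_s seconds.
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)


def disable_hang_dump():
    """Stop the periodic dump, e.g. once training has finished."""
    faulthandler.cancel_dump_traceback_later()
```

If all four processes report the same line (for example inside a collective call or the first forward pass), that line is a reasonable place to start looking.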

launch code

$ CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=2,3,4,5 taskset --cpu-list 23-34 python -m torch.distributed.launch --nproc_per_node 4 train.py
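One thing worth checking with this launch command is how each process learns its local rank. torch.distributed.launch hands the rank to the script either as a --local_rank command-line argument or, with --use_env (and with torchrun), through the LOCAL_RANK environment variable; which one applies depends on the PyTorch version. The sketch below reads both. The function name resolve_local_rank is hypothetical, and whether a missing rank is what causes the hang here is an untested assumption.

```python
import argparse
import os


def resolve_local_rank(argv=None):
    """Return this process's local rank, or -1 if no distributed launcher set one."""
    parser = argparse.ArgumentParser()
    # The launcher may pass --local_rank=<n>; some newer releases spell it --local-rank.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=-1)
    # parse_known_args ignores any other arguments the script itself defines.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank != -1:
        return args.local_rank
    # With --use_env or torchrun, the rank arrives through the environment instead.
    return int(os.environ.get("LOCAL_RANK", -1))
```

The returned value could then be passed as local_rank=... when building TrainingArguments by hand, assuming the installed transformers version does not already pick it up from the environment.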

trainer.py

import gc

import torch
import wandb
from transformers import TrainingArguments


class TrainManager:
  def __init__(self, config):
    self.config = config

    # ...
    # set up datasets and other components
    # ...

    self.training_args = TrainingArguments(
        output_dir=self.config.finetuned_model_path,
        group_by_length=False,
        sharded_ddp='simple',
        report_to='wandb',
        run_name=self.config.project_name,
        per_device_train_batch_size=self.config.batch_size,
        gradient_accumulation_steps=self.config.gradient_accumulation_steps,
        evaluation_strategy="steps",
        num_train_epochs=20,
        fp16=self.config.use_fp16,
        gradient_checkpointing=self.config.gradient_checkpointing,
        save_steps=self.config.save_steps,
        eval_steps=self.config.eval_steps,
        logging_steps=self.config.logging_steps,
        learning_rate=self.config.learning_rate,
        weight_decay=self.config.weight_decay,
        warmup_steps=self.config.warmup_steps,
        save_total_limit=self.config.save_total_limit,
        disable_tqdm=False,
        load_best_model_at_end=True
    )

  def run(self):
    gc.collect()
    torch.cuda.empty_cache()
    self.trainer.train()
    wandb.finish()
    torch.cuda.empty_cache()

# train.py
from trainer import TrainManager

if __name__ == '__main__':
    # config: project settings object, defined elsewhere
    tm = TrainManager(config)
    tm.run()
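When moving from one GPU to four processes, note that per_device_train_batch_size applies to each process, so the number of samples per optimizer step grows with the number of GPUs. A small sketch of the arithmetic follows; the function name and the example values (8, 2, 4) are hypothetical and are not taken from the config above.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, world_size):
    """Samples contributing to one optimizer step across all processes."""
    return per_device_batch_size * gradient_accumulation_steps * world_size


# Hypothetical values: batch size 8 per GPU, 2 accumulation steps, 4 GPUs.
single_gpu = effective_batch_size(8, 2, 1)
four_gpus = effective_batch_size(8, 2, 4)
print(single_gpu, four_gpus)
```

If the single-GPU setup trained well, learning_rate and warmup_steps may need revisiting for the larger effective batch.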