Saving checkpoint is too slow with deepspeed

My custom model is an nn.Module that wraps self.lm (a causal large language model). I’m trying to train it with DeepSpeed ZeRO-3 on 8x A100 GPUs. I created a ds_config.json file and ran a shell script like this:
deepspeed --num_gpus=8 finetune/ --deepspeed ds_config.json ...
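The original ds_config.json isn’t shown. As a hedged sketch, assuming the Hugging Face Trainer integration (which accepts "auto" values and fills them from TrainingArguments), a minimal ZeRO-3 config could look like the following. The stage3_gather_16bit_weights_on_model_save key is what lets Accelerate’s get_state_dict consolidate the partitioned weights under ZeRO-3; every other value here is an assumption, not the poster’s setup.

```python
import json
from typing import Any, Dict


def make_zero3_config(gather_on_save: bool = True) -> Dict[str, Any]:
    """Build a minimal ZeRO-3 config (assumed values) for the Hugging Face Trainer."""
    return {
        # "auto" lets the Trainer fill these in from TrainingArguments
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "bf16": {"enabled": "auto"},
        "zero_optimization": {
            "stage": 3,
            # Consolidate the partitioned 16-bit weights when the model is saved;
            # if False, Accelerate's get_state_dict refuses to gather them under ZeRO-3.
            "stage3_gather_16bit_weights_on_model_save": gather_on_save,
        },
    }


if __name__ == "__main__":
    # Print the JSON that could be saved as ds_config.json
    print(json.dumps(make_zero3_config(), indent=2))
```

Saving the printed JSON as ds_config.json is one way to produce the file passed via --deepspeed.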

I also overrode the Trainer’s “_save” method like this:

  def _save(self, output_dir: Optional[str] = None):
      output_dir = output_dir if output_dir is not None else self.args.output_dir
      os.makedirs(output_dir, exist_ok=True)
      logger.info("Saving model checkpoint to %s", output_dir)

      # Gather the full (unpartitioned) weights from all ZeRO-3 ranks
      state_dict = self.accelerator.get_state_dict(self.deepspeed)
      self.model.lm.save_pretrained(output_dir, state_dict=state_dict)

      # Good practice: save your training arguments together with the trained model
      torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
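The override passes the wrapper’s full state dict to self.model.lm.save_pretrained. Assuming the LLM lives in the attribute self.lm, those keys would carry an "lm." prefix that a later from_pretrained call on the LLM alone would not expect. The hypothetical helper below (not part of Transformers or Accelerate) keeps only that submodule’s entries and strips the prefix. Separately, and as an unverified observation: gathering ZeRO-3 weights is a collective operation that every rank has to join, so if the Trainer version in use calls _save only on the main process, calling get_state_dict inside _save could by itself leave the save hanging.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_submodule_prefix(state_dict: Dict[str, V], prefix: str = "lm.") -> Dict[str, V]:
    """Keep only the entries under `prefix` and drop the prefix from their keys.

    Hypothetical helper; assumes the wrapper stores the LLM as `self.lm`, so its
    weights appear as "lm.<original key>" in the wrapper's state dict.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)  # entries of other submodules are dropped
    }


# Hypothetical use inside the `_save` override:
#   state_dict = strip_submodule_prefix(self.accelerator.get_state_dict(self.deepspeed))
#   self.model.lm.save_pretrained(output_dir, state_dict=state_dict)
```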

But saving the checkpoint never finishes, and I never get a pytorch_model.bin file. I think gathering all of the model’s parameters from 8 GPUs makes saving the checkpoint too slow.
How can I solve this problem?

Hi, I have the same problem saving codet5-6b with ZeRO-3. The logs say pytorch_model.bin was saved, but the file isn’t there and the process hangs. Did you find a solution?

With 8x A100, I think you should use ZeRO-2; it’ll be a lot faster for both training and model saving.

ZeRO-3 is very slow to save the model because it uses only one GPU (GPU 0) to combine all the layers, and it combines only one layer at a time.
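A toy, pure-Python simulation (not DeepSpeed’s implementation) of the explanation above: under ZeRO-3 every parameter is partitioned across ranks, so a consolidated save has to reassemble each one in turn and keep the result on a single rank. Under ZeRO-2 the parameters are not partitioned (only gradients and optimizer states are), so each rank already holds the full weights and no such gather is needed. The names partition and gather_on_rank0 are hypothetical.

```python
from typing import Dict, List

# shards[rank][param_name] -> the slice of that parameter held by `rank`
Shards = List[Dict[str, List[float]]]


def partition(full: Dict[str, List[float]], world_size: int) -> Shards:
    """Split every parameter into world_size contiguous slices (toy ZeRO-3 partitioning)."""
    shards: Shards = [{} for _ in range(world_size)]
    for name, values in full.items():
        chunk = -(-len(values) // world_size)  # ceiling division
        for rank in range(world_size):
            # Later ranks may get a short or empty slice (real DeepSpeed pads instead)
            shards[rank][name] = values[rank * chunk:(rank + 1) * chunk]
    return shards


def gather_on_rank0(shards: Shards) -> Dict[str, List[float]]:
    """Rebuild the full state dict one parameter at a time, as a single rank would."""
    consolidated: Dict[str, List[float]] = {}
    for name in shards[0]:  # sequential: one parameter per step
        # In real ZeRO-3 this step is a cross-GPU gather; here it is a plain concatenation.
        full_param: List[float] = []
        for rank_shard in shards:
            full_param.extend(rank_shard[name])
        consolidated[name] = full_param  # only "rank 0" keeps the assembled weights
    return consolidated
```

For example, partition({"layer0.weight": [1.0, 2.0, 3.0, 4.0, 5.0]}, world_size=4) followed by gather_on_rank0 returns the original dict; the loop’s cost grows with the number of parameters, since each one is reassembled before the next.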

I’m using 4x L4 with ZeRO-2 to fine-tune SDXL. All 4 GPUs run at 100% while the checkpoint is being saved, but nothing gets saved, even after waiting 30 minutes. Is this expected, or am I doing something wrong?

I’m in the same situation. Have you found a good solution yet? @jakobsal

Hey @jax2000, not yet; I’m getting back to this today. What I understand so far is that you might have to set some environment variables, and which ones depends on your setup. I’ve found that this GitHub thread has some interesting things to try out. This NVIDIA documentation on environment variables is probably also useful.
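The linked thread and NVIDIA page aren’t reproduced here, so which variables help is setup-dependent. As an assumption-laden sketch: NCCL_DEBUG=INFO (prints NCCL debug information, useful for seeing where a collective stalls) and NCCL_P2P_DISABLE=1 (disables NCCL’s peer-to-peer GPU transport) are real NCCL variables that are commonly tried for hangs, but whether either fixes this particular save hang is unverified. They have to be set before the distributed process group is initialized, e.g. at the very top of the training script or exported in the shell before running deepspeed. The function name configure_nccl_env is hypothetical.

```python
import os
from typing import Dict

# Real NCCL variables; whether either one fixes a given hang is an assumption.
NCCL_VARS = ("NCCL_DEBUG", "NCCL_P2P_DISABLE")


def configure_nccl_env(disable_p2p: bool = False) -> Dict[str, str]:
    """Set NCCL debug/transport variables; call before torch.distributed is initialized."""
    # Print NCCL debug information unless the user already chose a level.
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    if disable_p2p:
        # Turn off direct GPU-to-GPU (NVLink/PCIe P2P) transport; NCCL uses other transports.
        os.environ["NCCL_P2P_DISABLE"] = "1"
    # Report which of the variables are now set
    return {name: os.environ[name] for name in NCCL_VARS if name in os.environ}


if __name__ == "__main__":
    print(configure_nccl_env(disable_p2p=True))
```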