How to Set Up Deferred Init with Accelerate + DeepSpeed?

I’d like to defer initialization of my model until after DeepSpeed finishes sharding to avoid running OOM. Any tips on how to set that up?

I found the init_empty_weights context manager, which puts the model on the "meta" device, but I'm not sure how to initialize the weights after calling accelerator.prepare in a way that's compatible with DeepSpeed.

In case it's relevant, I want to note that I'm using a custom model, not a Hugging Face model.
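
For reference, this is roughly the pattern I have in mind (a simplified sketch; MyCustomModel and model_config are stand-ins for my own code):

import torch
from accelerate import Accelerator, init_empty_weights

accelerator = Accelerator()

with init_empty_weights():
    # Parameters are created on the "meta" device, so no real memory is allocated yet.
    model = MyCustomModel(model_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# This is the part I'm unsure about: after prepare(), how do I materialize / load
# the real weights in a way that works with DeepSpeed ZeRO-3 sharding?
model, optimizer = accelerator.prepare(model, optimizer)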


I have the same question. When I use init_empty_weights() and load_checkpoint_and_dispatch, I don't get OOM errors when loading the model, but with DeepSpeed I do. Any solutions?
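
Concretely, something along these lines works for me without DeepSpeed (a sketch using a Hugging Face causal LM just for illustration; MODEL_PATH and the no_split class name are placeholders):

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(MODEL_PATH)

with init_empty_weights():
    # Build the model skeleton on the "meta" device; no weights in memory yet.
    model = AutoModelForCausalLM.from_config(config)

# Stream the checkpoint shards in and place them across the available devices.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=MODEL_PATH,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],
)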


You may need something like with deepspeed.zero.Init():.
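
For a custom model, a rough sketch of that could look like the following (MyCustomModel, model_config, and the config path are placeholders; with a Hugging Face model and the Trainer, from_pretrained does this automatically when ZeRO-3 is detected):

import deepspeed

# Constructing the model inside zero.Init partitions the parameters across ranks
# as each submodule is created, so the full model never exists on a single device.
with deepspeed.zero.Init(config_dict_or_path="ds_config_zero3.json"):
    model = MyCustomModel(model_config)

# The model is then handed to the usual path, e.g. accelerator.prepare(...) or the Trainer.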

Shouldn't the Trainer already take care of this when running with Accelerate and DeepSpeed?


I think yes, as long as the deepspeed= argument is passed to TrainingArguments, except for some pitfalls.

training_args = TrainingArguments(..., deepspeed="ds_config_zero3.json")
trainer = Trainer(...)

I have the TrainingArguments before the Trainer as well. The thing is, when I don't use deepspeed.zero.Init(), I don't get this in my logs:

Detected DeepSpeed ZeRO-3: activating zero.init() for this model

But when I do use it, that line does appear in my logs, yet it is not compatible with quantization: it gives errors when loading the model. We already know that DeepSpeed ZeRO Stage 3 is compatible with QLoRA, so I'm wondering why I don't get that "activating zero.init()" line in my logs if the Trainer already works with DeepSpeed.

This is my setup by the way:

    # ------------------------- Training arguments -------------------------
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        max_steps=-1,
        weight_decay=0.01,
        logging_strategy="steps",
        logging_steps=25,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=25,
        save_steps=25,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        label_names=["labels"],
        max_grad_norm=1.0,
        bf16=True,
        log_level="info",
        log_level_replica="warning",
        remove_unused_columns=False,
        eval_accumulation_steps=16,
    )

    # ------------------------- Load model and quantization-------------------------
    optional_kwargs = {}

    if args.quant and LLAMA:
        print("🔢 Using 4-bit quantization...")
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=DTYPE,
            bnb_4bit_quant_storage=DTYPE,
        )
        optional_kwargs["quantization_config"] = quant_config
    else:
        print("🔢 Loading model without quantization...")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        local_files_only=True,
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=DTYPE,
        **optional_kwargs
    )

    model.config.pad_token_id = tokenizer.pad_token_id
    if LLAMA:
        model.config.pretraining_tp = 1

    # ------------------------- Gradient Checkpointing -------------------------
    training_args.gradient_checkpointing = True
    model.config.use_cache = False
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    # ------------------------- LORA -------------------------
    if args.lora:
        print("✨ Applying LoRA...")
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules="all-linear",
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, lora_config)

        model.print_trainable_parameters()

    # ------------------------- Trainer -------------------------
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=final_dataset["train"],
        eval_dataset=final_dataset["test"],
        data_collator=data_collator,
        # compute_metrics=custom_metrics,
        callbacks=trainer_callbacks
    )

    # ---------------------------- Train ----------------------------
    trainer.train()

accelerate config:

compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  deepspeed_config_file: "path/to/deepspeed_config_bf16.json"
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

deepspeed config:

{
    "bf16": { "enabled": true },
    "fp16": { "enabled": false },

    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_scatter": true,
        "stage3_gather_16bit_weights_on_model_save": true,

        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        }
    },

    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,

    "wall_clock_breakdown": false
}

Launch command:

accelerate launch --config_file=accelerate_config_deepspeed.yaml --num_machines=$SLURM_NNODES --machine_rank=$SLURM_NODEID causal_lm.py --config causal_lm_config.yaml


Perhaps related to this?