11B model gets OOM with DeepSpeed ZeRO-3 on 8x 32 GB V100

Can anyone help me with single-node distributed training? The idea of my custom model is to merge some layers of two 7B Llama models, the way MoE models do, which yields an 11B model. I'm loading the model in torch.float16, except for the lm_head weight, which is in torch.float32. I'm assuming this would require at least 11*2*2 = 44 GB of GPU RAM for training. I used estimate_zero3_model_states_mem_needs_all_live to check the recommended ZeRO-3 settings and created an accelerate config file accordingly.
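As a rough sanity check on that estimate, here is a minimal sketch of the usual accounting for model states in fp16 mixed-precision training with Adam: 2 bytes for fp16 weights, 2 for fp16 gradients, and 12 for the fp32 master weights, momentum, and variance. It assumes every parameter is trained (if only the LoRA parameters end up trainable, the gradient and optimizer terms shrink a lot) and it ignores activations and temporary buffers entirely. The function name `model_state_gb` is made up for illustration.

```python
# Bytes of model state per parameter for fp16 mixed precision + Adam.
# Assumption: all parameters are trainable; activations are NOT included.
BYTES_PER_PARAM = {
    "fp16_weights": 2,
    "fp16_grads": 2,
    "fp32_master_weights": 4,
    "adam_momentum": 4,
    "adam_variance": 4,
}


def model_state_gb(n_params, n_gpus):
    """Per-GPU model-state memory in GB when ZeRO-3 shards everything evenly."""
    total_bytes = n_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / n_gpus / 1e9


for gpus in (2, 4, 8):
    print(f"{gpus} GPUs: {model_state_gb(11e9, gpus):.1f} GB per GPU")
```

Under these assumptions the model states alone come to 176 GB in total, so 2 or 4 GPUs with 32 GB each would not fit them without offloading, and 8 GPUs leave limited headroom for activations.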

The model loads successfully (and I'm guessing it is split evenly across the 8 GPUs, because each one shows around 6000 MiB of memory in use). However, once the forward call starts, the code fails with an OOM error. BTW, 8 GPUs seem a bit extravagant for a model of this size, but I had no luck with 2 or 4 GPUs either.
This is my config.yaml, shown as the answers I gave to accelerate config:

    In which compute environment are you running? This machine
    Which type of machine are you using? multi-GPU
    How many different machines will you use (use more than 1 for multi-node training)? [1]: 1
    Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: NO
    Do you wish to optimize your script with torch dynamo? [yes/NO]: NO
    Do you want to use DeepSpeed? [yes/NO]: yes
    Do you want to specify a json file to a DeepSpeed config? [yes/NO]: NO
    What should be your DeepSpeed's ZeRO optimization stage? 3
    Where to offload optimizer states? none
    Where to offload parameters? none
    How many gradient accumulation steps you're passing in your script? [1]: 1
    Do you want to use gradient clipping? [yes/NO]: NO
    Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: yes
    Do you want to enable deepspeed.zero.Init when using ZeRO Stage-3 for constructing massive models? [yes/NO]: NO
    How many GPU(s) should be used for distributed training? [1]: 8
    Do you wish to use FP16 or BF16 (mixed precision)? fp16
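For readers more used to DeepSpeed JSON configs, the answers above correspond roughly to the settings sketched below. This is a hand-written approximation, not the config that accelerate actually generates; the micro batch size of 2 is taken from the DataLoader batch size in the training script, and the deepspeed.zero.Init answer is an accelerate-level option, so it has no key here.

```python
import json

# Approximate DeepSpeed config matching the questionnaire answers (assumption:
# hand-written for illustration, not dumped from accelerate).
ds_config = {
    "train_micro_batch_size_per_gpu": 2,  # batch_size used in the DataLoader
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

print(json.dumps(ds_config, indent=2))
```

Changing the two "device" entries from "none" to "cpu" is how optimizer and parameter offloading would be switched on in a config of this shape.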

This is my training script.
I implemented my custom model by subclassing and overriding classes from the transformers library. The 2 models to be merged are loaded with LlamaForCausalLM.from_pretrained() inside the MoeLlamaForCausalLM class, and I then copy their weights into my model layer by layer.

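    import torch
    from accelerate import Accelerator
    from accelerate.utils import DummyOptim, DummyScheduler
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from transformers import AutoTokenizer, get_scheduler

    model = MoeLlamaForCausalLM(config, trainable=True)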

    for name, param in model.named_parameters():
        if param.ndim == 1:
            # cast the small parameters (e.g. layernorm) to fp32 for stability
            param.data = param.data.to(torch.float32)
        # only the LoRA parameters are trained; everything else is frozen
        if "lora" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(...)
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(...)

    training_config = {
        "learning_rate": 5e-5,
        "gradient_accumulation_steps": 1,
        "batch_size": 2,
        "num_epochs": 1,
        "save_step": 100,
        "save_dir": "./accelerate_output",
        "lr_scheduler_type": "linear",
        "num_warmup_steps": 0,
    }

    accelerator = Accelerator(mixed_precision='fp16')

    if accelerator.is_main_process:
        pass  # main-process-only setup not shown

    # Creates Dummy Optimizer if `optimizer` was specified in the config file else creates Adam Optimizer
    optimizer_cls = (
        torch.optim.AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(model.parameters(), lr=training_config["learning_rate"])

    train_dataloader = DataLoader(dataset, shuffle=True, batch_size=training_config["batch_size"])

    num_training_steps = training_config["num_epochs"] * len(train_dataloader)
    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `lr_scheduler_type` Scheduler
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_scheduler(
            name=training_config["lr_scheduler_type"],
            optimizer=optimizer,
            num_warmup_steps=training_config["num_warmup_steps"],
            num_training_steps=num_training_steps,
        )
    else:
        lr_scheduler = DummyScheduler(
            optimizer, total_num_steps=num_training_steps, warmup_num_steps=training_config["num_warmup_steps"]
        )

    train_dataloader, model, optimizer, lr_scheduler = accelerator.prepare(
        train_dataloader, model, optimizer, lr_scheduler
    )

    progress_bar = tqdm(range(training_config["num_epochs"] * len(train_dataloader)), disable=not accelerator.is_main_process)
    step = 0
    for epoch in range(training_config["num_epochs"]):
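        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            # backward pass and parameter update go through accelerate/DeepSpeed
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            step += 1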

        if accelerator.is_main_process and step % training_config["save_step"] == 0:
            unwrapped_model = accelerator.unwrap_model(model)

Thanks in advance for any advice!