Enabling gradient checkpointing together with DeepSpeed ZeRO3 raises a training failure

I’m using diffusers and DeepSpeed to train my diffusion models. All my code worked fine until I upgraded diffusers to the newest version.

Issue

When using DeepSpeed ZeRO3 and enabling UNet gradient checkpointing at the same time, training with diffusers <= 0.16.1 works fine, but training with diffusers >= 0.17.0 raises the following exception:

(I’ve checked 0.15.1, 0.16.1, 0.17.0, 0.17.1, 0.18.2, 0.19.1, and 0.20.2)

Either disabling gradient checkpointing or switching to DeepSpeed ZeRO2 avoids the issue; a short sketch of both workarounds follows the traceback below.

Traceback (most recent call last):
  File "/share/project/zhangfan/codes/ds_diffuser/minimal_example.py", line 220, in <module>
    main()
  File "/share/project/zhangfan/codes/ds_diffuser/minimal_example.py", line 213, in main
    model.backward(loss)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1845, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1962, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 62, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/share/project/zhangfan/misc/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The size of tensor a (0) must match the size of tensor b (1280) at non-singleton dimension 1
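
As a stop-gap, here is a minimal sketch of the two workarounds mentioned above (it assumes the same UNet checkpoint as in the reproduction below and is not a real fix):

# Workaround 1: keep ZeRO3 but do not enable gradient checkpointing on the UNet.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet",
)
# unet.enable_gradient_checkpointing()  # leaving this call out avoids the crash

# Workaround 2: keep gradient checkpointing but fall back to ZeRO stage 2,
# i.e. set "stage": 2 under "zero_optimization" in deepspeed_config.json.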

Reproduction

I wrote a minimal script to reproduce the issue:

import argparse
import datetime
import os
import traceback

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

import deepspeed
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler


class MinimalDiffusion(nn.Module):

    def __init__(self) -> None:
        super().__init__()

        pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

        self.scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path, subfolder="scheduler",
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path, subfolder="vae",
        )

        self.vae.requires_grad_(False)

        self.unet.enable_xformers_memory_efficient_attention()
        self.unet.enable_gradient_checkpointing()

    def forward(self, **kwargs):
        if self.training:
            return self._forward_train(**kwargs)
        else:
            return self._forward_eval(**kwargs)

    def _forward_train(
        self,
        *,
        vae_t_image,
        encoder_hidden_states,
        **_,
    ):
        latents = self.vae.encode(vae_t_image).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        bsz, ch, h, w = latents.shape
        device = latents.device

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz,), device=device)
        timesteps = timesteps.long()
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        model_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        ).sample

        loss = F.mse_loss(model_pred, noise, reduction="mean")

        return loss

    def _forward_eval(self, **kwargs):
        pass

    def train(self, mode: bool = True):
        # Keep the frozen VAE in eval mode; only the UNet follows the training flag.
        self.training = mode
        self.vae.eval()
        self.unet.train(mode)
        return self

def init_distributed_mode(args):
    assert torch.cuda.is_available() and torch.cuda.device_count() > 1

    args.distributed = True
    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.local_rank = int(os.environ['LOCAL_RANK'])
    args.dist_backend = "nccl"

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(0, 7200)
    )
    dist.barrier()

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.allow_tf32 = True


def get_parameter_groups(model):
    parameter_group_names = {}
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.ndim <= 1:
            group_name = "wo_decay"
            weight_decay = 0.
        else:
            group_name = "w_decay"
            weight_decay = 0.01

        lr = 5e-5

        if group_name not in parameter_group_names:
            parameter_group_names[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }
            parameter_group_vars[group_name] = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    return list(parameter_group_vars.values())


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", default=False, action="store_true")
    parser.add_argument("--world-size", default=1, type=int)
    parser.add_argument("--rank", default=-1, type=int)
    parser.add_argument("--gpu", default=-1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist-url", default="env://")
    parser.add_argument("--dist-backend", default="nccl", type=str)
    deepspeed.add_config_arguments(parser)

    return parser.parse_args()


def main():
    args = parse_args()
    args.deepspeed_config = "deepspeed_config.json"
    print(args)
    model = MinimalDiffusion()
    parameters = get_parameter_groups(model)

    model, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=parameters,
    )

    device = torch.device("cuda")
    dtype = torch.bfloat16
    model.train()
    for _ in range(5):
        vae_t_image = torch.randn(32, 3, 512, 512, dtype=dtype, device=device)
        encoder_hidden_states = torch.randn(32, 77, 768, dtype=dtype, device=device)

        loss = model(
            vae_t_image=vae_t_image,
            encoder_hidden_states=encoder_hidden_states,
        )

        model.backward(loss)
        model.step()
        torch.cuda.synchronize()


if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        print(ex)
        print(traceback.format_exc())

Here is my deepspeed_config.json:

{
  "train_micro_batch_size_per_gpu": 32,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 1000,
  "flops_profiler": {
    "enabled": true,
    "profile_step": -1,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true
  },
  "zero_allow_untested_optimizer": true,
  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "profile": true
  },
  "optimizer": {
    "type": "Adam",
    "adam_w_mode": true,
    "params": {
      "lr": 5e-05,
      "weight_decay": 0.01,
      "bias_correction": true,
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-08
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "total_num_steps": 20000,
      "warmup_min_lr": 0,
      "warmup_max_lr": 5e-05,
      "warmup_num_steps": 5000,
      "warmup_type": "linear"
    }
  },
  "bf16": {
    "enabled": true
  },
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": 3,
    "contiguous_gradients": true,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 27000000.0,
    "allgather_bucket_size": 27000000.0,
    "stage3_gather_16bit_weights_on_model_save": false,
    "stage3_prefetch_bucket_size": 27000000.0,
    "stage3_param_persistence_threshold": 51200,
    "sub_group_size": 1000000000.0,
    "stage3_max_live_parameters": 1000000000.0,
    "stage3_max_reuse_distance": 1000000000.0
  }
}

My launch script is: deepspeed --hostfile=hostfile minimal_example.py
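
For completeness, the hostfile only lists the nodes and their GPU slots; a minimal sketch for my setup, assuming a single node with 8 GPUs (the hostname is a placeholder), would be:

node-1 slots=8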

System info and environment

hardware: 8 x A100 40G

system: Ubuntu 20.04.6 LTS

nvidia driver version: 470.182.03

cuda version: 11.4

package versions:

    deepspeed==0.9.2
    torch==2.0.1
    diffusers>=0.15.1,<=0.20.2