ValueError: Query/Key/Value should either all have the same dtype, or (in the quantized case) Key/Value should have dtype torch.int32

--enable_xformers_memory_efficient_attention fails with the following error when the training script is launched with accelerate:

  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 348, in _memory_efficient_attention_forward_requires_grad
    inp.validate_inputs()
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/xformers/ops/fmha/common.py", line 121, in validate_inputs
    raise ValueError(
ValueError: Query/Key/Value should either all have the same dtype, or (in the quantized case) Key/Value should have dtype torch.int32
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
Steps:   0%|                                           | 0/1000 [00:02<?, ?it/s]
[2023-11-25 14:03:17,898] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 5264) of binary: /anaconda/envs/diffusers-ikin/bin/python
Traceback (most recent call last):
  File "/anaconda/envs/diffusers-ikin/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/accelerate/commands/launch.py", line 985, in launch_command
    multi_gpu_launcher(args)
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/accelerate/commands/launch.py", line 654, in multi_gpu_launcher
    distrib_run.run(args)
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/anaconda/envs/diffusers-ikin/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py FAILED
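
For context, the same ValueError can be reproduced directly against xformers outside of the training script (a minimal sketch with placeholder shapes, not the tensors from the training run):

    import torch
    import xformers.ops as xops

    # query stays in float32 while key/value are in float16, which is the
    # exact mismatch reported in the traceback above
    q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float32)
    k = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)

    # raises: ValueError: Query/Key/Value should either all have the same dtype, ...
    xops.memory_efficient_attention(q, k, v)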

I am running accelerate as follows:

accelerate launch diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py \
      --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
      --instance_data_dir={input_dir} \
      --output_dir={output_dir} \
      --instance_prompt=instance_prompt \
      --mixed_precision="fp16" \
      --resolution=1024 \
      --train_batch_size=1 \
      --gradient_accumulation_steps=4 \
      --learning_rate=1e-4 \
      --lr_scheduler="constant" \
      --lr_warmup_steps=0 \
      --checkpointing_steps=500 \
      --max_train_steps=1000 \
      --seed="0" \
      --checkpoints_total_limit=5 \
      --enable_xformers_memory_efficient_attention

Without the --enable_xformers_memory_efficient_attention flag, training works fine.

Accelerate config:

{
  "compute_environment": "LOCAL_MACHINE",
  "debug": false,
  "distributed_type": "MULTI_GPU",
  "downcast_bf16": false,
  "machine_rank": 0,
  "main_training_function": "main",
  "mixed_precision": "no",
  "num_machines": 1,
  "num_processes": 2,
  "rdzv_backend": "static",
  "same_network": false,
  "tpu_use_cluster": false,
  "tpu_use_sudo": false,
  "use_cpu": false
}

Versions:

xformers==0.0.23.dev687
accelerate==0.24.1
torch==2.1.0
torchvision==0.16.1

In the Automatic1111 webui there is an option “Upcast cross attention layer to float32”, implemented with this code:

    # remember the original dtype so the output can be cast back
    dtype = q.dtype
    if shared.opts.upcast_attn:
        # upcast the attention inputs to float32 for the computation
        q, k, v = q.float(), k.float(), v.float()
    # make the tensors contiguous before the fused attention kernel
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # cast the result back to the original dtype
    out = out.to(dtype)

Link to fix:
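
A comparable workaround on the xformers path would be to align the dtypes before the memory-efficient attention call, mirroring the A1111 upcast option above. This is only a sketch of the idea (the function name and the place to hook it in are assumptions, not the diffusers implementation):

    import xformers.ops as xops

    # Sketch only: query/key/value are assumed to be the projected
    # attention tensors inside an attention processor.
    def xformers_attention_dtype_safe(query, key, value, upcast=False):
        orig_dtype = query.dtype
        if upcast:
            # compute attention in float32, like A1111's upcast_attn option
            query, key, value = query.float(), key.float(), value.float()
        else:
            # or simply make key/value match the query dtype
            key, value = key.to(query.dtype), value.to(query.dtype)
        out = xops.memory_efficient_attention(query, key, value)
        # cast the result back to the original dtype
        return out.to(orig_dtype)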