### 🐛 Describe the bug
There appears to be something wrong with the MPS cache. It seems that either it's not releasing memory when it ideally should, or the freeable memory in the cache is not being taken into account when the check for available space occurs.
The issue occurs on the current nightly (see Versions) and on 2.0.1.
This issue affects performance at best and terminates the application at worst.
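
The two hypotheses above can be made concrete with a toy caching allocator. This is a minimal sketch, assuming a simplified allocator in which freed blocks stay in a cache and are never reused for new requests. `ToyCachingAllocator`, `run_two_models`, the `reclaim_cache_on_oom` flag and all sizes are hypothetical and do not describe PyTorch's actual MPS allocator.

```python
from dataclasses import dataclass


@dataclass
class ToyCachingAllocator:
    """Toy caching allocator; all sizes in MB. NOT PyTorch's MPS allocator."""
    limit_mb: int                 # hypothetical cap on memory held (live + cached)
    reclaim_cache_on_oom: bool    # free idle cached blocks before failing?
    in_use_mb: int = 0            # memory backing live tensors
    cached_mb: int = 0            # freed by tensors but still held by the allocator

    def empty_cache(self) -> None:
        # Return every idle cached block, the role mps.empty_cache() plays.
        self.cached_mb = 0

    def free(self, size_mb: int) -> None:
        # Freed memory stays in the cache instead of going back to the system.
        self.in_use_mb -= size_mb
        self.cached_mb += size_mb

    def malloc(self, size_mb: int) -> None:
        # Simplification: cached blocks are never reused for new requests
        # (e.g. the second model asks for different block sizes).
        if self.in_use_mb + self.cached_mb + size_mb > self.limit_mb:
            if self.reclaim_cache_on_oom:
                self.empty_cache()
            if self.in_use_mb + self.cached_mb + size_mb > self.limit_mb:
                raise MemoryError(
                    f"requested {size_mb} MB with {self.in_use_mb} MB in use, "
                    f"{self.cached_mb} MB cached, limit {self.limit_mb} MB"
                )
        self.in_use_mb += size_mb


def run_two_models(reclaim_cache_on_oom: bool, manual_empty_cache: bool) -> bool:
    """Mimic the repro: load and free model 1, then load model 2. Sizes are made up."""
    alloc = ToyCachingAllocator(limit_mb=9000, reclaim_cache_on_oom=reclaim_cache_on_oom)
    alloc.malloc(5000)        # first pipeline in use
    alloc.free(5000)          # pipe_prior = None; gc.collect()
    if manual_empty_cache:
        alloc.empty_cache()   # mps.empty_cache()
    try:
        alloc.malloc(6000)    # second pipeline
        return True
    except MemoryError:
        return False


for reclaim, manual in [(False, True), (False, False), (True, False)]:
    outcome = "ok" if run_two_models(reclaim, manual) else "out of memory"
    print(f"reclaim_cache_on_oom={reclaim}, manual_empty_cache={manual}: {outcome}")
```

With these made-up sizes, only the `(False, False)` combination runs out of memory, because the 5000 MB freed by the first model sit in the cache and are never reclaimed. Whether the real MPS allocator behaves like `reclaim_cache_on_oom=False` is exactly what this report questions; the toy model does not show that it does.
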
Here's a script that reproduces it:
```python
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
from torch import mps
import torch
import fp16fixes
import gc
fp16fixes.fp16_fixes()
pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior.to("mps")
prompt = "A car exploding into colorful dust"
out = pipe_prior(prompt)
image_emb = out.image_embeds
zero_image_emb = out.negative_image_embeds
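pipe_prior = None   # drop the last reference to the prior pipeline
gc.collect()        # make sure its tensors are actually freed
mps.empty_cache()   # return cached MPS memory; removing this line leads to the failure below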
pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.to("mps")
pipe.enable_attention_slicing()
image = pipe(
image_embeds=image_emb,
negative_image_embeds=zero_image_emb,
height=1024,
width=1024,
num_inference_steps=30,
).images
image[0].save("cat.png")
```
This works without issue on an 8GB M1 Mac Mini; the two models run at:
```
100%|████████| 25/25 [00:07<00:00, 3.15it/s]
100%|████████| 30/30 [04:24<00:00, 8.82s/it]
```
Remove the `mps.empty_cache()` call and it fails during the second model run:
```
0%| | 0/30 [00:03<?, ?it/s]
Traceback (most recent call last):
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/8GB_M1_Diffusers_Scripts/sag/k2img.py", line 25, in <module>
image = pipe(
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py", line 272, in __call__
noise_pred = self.unet(
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 905, in forward
sample, res_samples = downsample_block(
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 1662, in forward
hidden_states = attn(
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 321, in forward
return self.processor(
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1590, in __call__
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 374, in get_attention_scores
attention_probs = attention_scores.softmax(dim=-1)
RuntimeError: MPS backend out of memory (MPS allocated: 3.90 GB, other allocations: 4.94 GB, max allowed: 9.07 GB). Tried to allocate 387.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
```
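
The numbers in the error message show how small the shortfall is. This is a minimal sketch of that arithmetic, assuming the message uses 1024-based units (so 387 MB is converted to GB by dividing by 1024); `mps_oom_breakdown` is a hypothetical helper, not a PyTorch API.

```python
def mps_oom_breakdown(allocated_gb: float, other_gb: float,
                      max_allowed_gb: float, requested_mb: float):
    """Return (headroom_gb, requested_gb, exceeds) for an MPS out-of-memory message."""
    # Assumption: 1024-based units, so MB -> GB divides by 1024.
    requested_gb = requested_mb / 1024
    # Space left under the limit before this request.
    headroom_gb = max_allowed_gb - (allocated_gb + other_gb)
    return headroom_gb, requested_gb, requested_gb > headroom_gb


# Figures copied from the RuntimeError above.
headroom, requested, exceeds = mps_oom_breakdown(3.90, 4.94, 9.07, 387.00)
print(f"headroom before request: {headroom * 1024:.0f} MB")
print(f"requested:               {requested * 1024:.0f} MB")
print(f"request exceeds limit:   {exceeds}")
```

Roughly 236 MB is left under the limit against a 387 MB request, a shortfall of about 150 MB (the GB figures in the message are rounded to two decimals, so these are approximate). If idle cached blocks left behind by the freed prior pipeline are counted in either the "MPS allocated" or "other allocations" figure, releasing them would more than cover that shortfall, which is consistent with `mps.empty_cache()` avoiding the error.
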
If I reduce the height and width to 512, it runs to completion, but the second model takes about 40 seconds per iteration with heavy swap file access. With the cache emptied manually it runs at around 2 seconds per iteration.
The `fp16fixes` file works around issues with fp16 on MPS: without it, the script fails with a broadcast error on 2.0.1 and produces a bad image on the nightly I'm currently using. If I remove it, the memory issue still occurs on the nightly.
```python
# fp16fixes.py
import torch


def fp16_fixes():
    if torch.backends.mps.is_available():
        # Replace torch.empty with torch.zeros so new tensors are zero-initialised.
        torch.empty = torch.zeros

        # Run layer_norm in float32 for float16 MPS inputs, then cast back to float16.
        _torch_layer_norm = torch.nn.functional.layer_norm

        def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
            if input.device.type == "mps" and input.dtype == torch.float16:
                input = input.float()
                if weight is not None:
                    weight = weight.float()
                if bias is not None:
                    bias = bias.float()
                return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
            else:
                return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

        torch.nn.functional.layer_norm = new_layer_norm

        # Make permuted float16 MPS tensors contiguous.
        def new_torch_tensor_permute(input, *dims):
            # Accept both t.permute(0, 2, 1) and t.permute((0, 2, 1)).
            if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
                dims = dims[0]
            result = torch.permute(input, tuple(dims))
            # Compare device.type: comparing a torch.device to a plain str is False.
            if input.device.type == "mps" and input.dtype == torch.float16:
                result = result.contiguous()
            return result

        torch.Tensor.permute = new_torch_tensor_permute
```
### Versions
```
Collecting environment information...
PyTorch version: 2.1.0.dev20230724
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.4.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: version 3.24.4
Libc version: N/A
Python version: 3.10.11 (main, Apr 8 2023, 02:11:11) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-13.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1
Versions of relevant libraries:
[pip3] numpy==1.25.1
[pip3] torch==2.1.0.dev20230724
[pip3] torchvision==0.15.2
[conda] Could not collect
```
cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev