Inference is slow on M1 Mac despite MPS Torch backend

I’ve been trying to run a simple phi-2 example on my M1 MacBook Pro:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_default_device("mps")  # <-------- MPS backend

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

inputs = tokenizer('''def print_prime(n):
   """
   Print all primes between 1 and n
   """''', return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)

(This was after trying both the stable and nightly builds of PyTorch, installed with pip as well as with conda/micromamba, on Python 3.11.)
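
For completeness, the standard sanity check that the installed build actually exposes the MPS backend is:

import torch

print(torch.backends.mps.is_available())  # True when the MPS device can be used on this machine
print(torch.backends.mps.is_built())      # True when this PyTorch build was compiled with MPS support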

However, inference still takes a good 40 seconds, timing only the model.generate(...) call.

(After reading MPS device appears much slower than CPU on M1 Mac Pro · Issue #77799 · pytorch/pytorch · GitHub, I ran the same test with the model on CPU, and MPS is definitely faster than CPU, so at least nothing obviously broken is going on.)
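
The CPU-vs-MPS comparison was roughly along these lines (a sketch, not my exact script; it times only the generate() call and moves the model and inputs explicitly instead of relying on set_default_device, so the same function covers both devices):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = '''def print_prime(n):
   """
   Print all primes between 1 and n
   """'''

def time_generate(device):
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
    ).to(device)
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).to(device)
    start = time.perf_counter()
    model.generate(**inputs, max_length=200)
    if device == "mps":
        torch.mps.synchronize()  # make sure queued MPS work finishes before stopping the clock
    return time.perf_counter() - start

print(f"cpu: {time_generate('cpu'):.1f} s")
print(f"mps: {time_generate('mps'):.1f} s")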

On the other hand, running phi-2 through MLX with the mlx-lm library makes inference almost instantaneous.
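
For reference, the mlx-lm side is roughly this ("mlx-community/phi-2" is my guess at the converted checkpoint name; point load() at whichever MLX conversion of phi-2 you actually use):

from mlx_lm import load, generate

# Assumed checkpoint name; substitute the MLX-converted phi-2 weights you have locally or on the Hub.
model, tokenizer = load("mlx-community/phi-2")

prompt = '''def print_prime(n):
   """
   Print all primes between 1 and n
   """'''

text = generate(model, tokenizer, prompt=prompt, max_tokens=200, verbose=True)
print(text)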

Is this expected? Am I doing something wrong?

Same boat here: it's super slow and eats a lot of RAM in the process.