Cannot train transformer on Mac/MPS

I’m trying to replicate this notebook on my MacBook:

I get stuck here:

%%time
print(model.device)
trainer.train()

With these results:

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.
mps:0
  0%|          | 0/15228 [00:13<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed exec>:2

File ~/Documents/sombert/.venv/lib/python3.12/site-packages/transformers/trainer.py:1912, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1910         hf_hub_utils.enable_progress_bars()
   1911 else:
-> 1912     return inner_training_loop(
   1913         args=args,
   1914         resume_from_checkpoint=resume_from_checkpoint,
   1915         trial=trial,
   1916         ignore_keys_for_eval=ignore_keys_for_eval,
   1917     )

File ~/Documents/sombert/.venv/lib/python3.12/site-packages/transformers/trainer.py:2245, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2242     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2244 with self.accelerator.accumulate(model):
-> 2245     tr_loss_step = self.training_step(model, inputs)
   2247 if (
   2248     args.logging_nan_inf_filter
   2249     and not is_torch_xla_available()
   2250     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2251 ):
   2252     # if loss is nan or inf simply add the average of previous logged losses
   2253     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
...
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Placeholder storage has not been allocated on MPS device!

From looking around, it seems that at one point you had to send the model to the MPS device explicitly, but now that should happen automatically. I printed the model’s device in the cell above, and it reports mps:0.
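Since the model itself already reports mps:0, I wondered whether the training batches might be landing on the CPU instead. Here’s the sort of explicit placement older examples used, plus a quick device check on a batch (a rough sketch on my part, not code from the notebook):

import torch

# Move the model to MPS by hand, as older guides suggested:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

# Check which device an actual training batch lands on:
batch = next(iter(trainer.get_train_dataloader()))
print({k: v.device for k, v in batch.items() if hasattr(v, "device")})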

Here are my versions:

v3.12.3 (.venv) 
❯ pip list
Package            Version
------------------ -----------
accelerate         0.30.1
aiohttp            3.9.5
aiosignal          1.3.1
appnope            0.1.4
asttokens          2.4.1
attrs              23.2.0
certifi            2024.2.2
charset-normalizer 3.3.2
comm               0.2.2
datasets           2.19.1
debugpy            1.8.1
decorator          5.1.1
dill               0.3.8
executing          2.0.1
filelock           3.14.0
frozenlist         1.4.1
fsspec             2024.3.1
huggingface-hub    0.23.2
idna               3.7
ipykernel          6.29.4
ipython            8.24.0
jedi               0.19.1
Jinja2             3.1.4
jupyter_client     8.6.2
jupyter_core       5.7.2
MarkupSafe         2.1.5
matplotlib-inline  0.1.7
mpmath             1.3.0
multidict          6.0.5
multiprocess       0.70.16
nest-asyncio       1.6.0
networkx           3.3
numpy              1.26.4
packaging          24.0
pandas             2.2.2
parso              0.8.4
pexpect            4.9.0
pip                24.0
platformdirs       4.2.2
prompt_toolkit     3.0.45
psutil             5.9.8
ptyprocess         0.7.0
pure-eval          0.2.2
pyarrow            16.1.0
pyarrow-hotfix     0.6
Pygments           2.18.0
python-dateutil    2.9.0.post0
pytz               2024.1
PyYAML             6.0.1
pyzmq              26.0.3
regex              2024.5.15
requests           2.32.3
safetensors        0.4.3
six                1.16.0
stack-data         0.6.3
sympy              1.12.1
tokenizers         0.19.1
torch              2.3.0
tornado            6.4
tqdm               4.66.4
traitlets          5.14.3
transformers       4.42.0.dev0
typing_extensions  4.12.0
tzdata             2024.1
urllib3            2.2.1
wcwidth            0.2.13
xxhash             3.4.1
yarl               1.9.4

Any ideas?