**Issue description**
When uploading safetensors files as part of the `mlx_lm.fuse` step, all of the weight files with `.safetensors` extensions are missing the optional `format` metadata attribute. As a result, the uploaded weights cannot be loaded by `transformers` library users. (`mlx` loads them without a problem.)
**To Reproduce**
Run LoRA fine-tuning, then run fusing script:
```bash
python -m mlx_lm.fuse \
    --model google/gemma-7b-it \
    --adapter-file checkpoints/600_adapters.npz \
    --upload-repo alexweberk/gemma-7b-it-trismegistus \
    --hf-path google/gemma-7b-it
```
After the upload, I tried running:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "alexweberk/gemma-7b-it-trismegistus"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)
model.to("mps")

# format_prompt, system_prompt, and question are defined earlier in the notebook
input_text = format_prompt(system_prompt, question)
input_ids = tokenizer(input_text, return_tensors="pt").to("mps")

outputs = model.generate(
    **input_ids,
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0]))
```
Which gives the full error message below:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[14], line 7
      4 repo_id = "alexweberk/gemma-7b-it-trismegistus"
      6 tokenizer = AutoTokenizer.from_pretrained(repo_id)
----> 7 model = AutoModelForCausalLM.from_pretrained(repo_id)
      8 model.to('mps')
     10 input_text = format_prompt(system_prompt, question)

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    559 elif type(config) in cls._model_mapping.keys():
    560     model_class = _get_model_class(config, cls._model_mapping)
--> 561     return model_class.from_pretrained(
    562         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    563     )
    564 raise ValueError(
    565     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    566     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    567 )

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/transformers/modeling_utils.py:3502, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3493 if dtype_orig is not None:
   3494     torch.set_default_dtype(dtype_orig)
   3495 (
   3496     model,
   3497     missing_keys,
   3498     unexpected_keys,
   3499     mismatched_keys,
   3500     offload_index,
   3501     error_msgs,
-> 3502 ) = cls._load_pretrained_model(
   3503     model,
   3504     state_dict,
   3505     loaded_state_dict_keys,  # XXX: rename?
   3506     resolved_archive_file,
   3507     pretrained_model_name_or_path,
   3508     ignore_mismatched_sizes=ignore_mismatched_sizes,
   3509     sharded_metadata=sharded_metadata,
   3510     _fast_init=_fast_init,
   3511     low_cpu_mem_usage=low_cpu_mem_usage,
   3512     device_map=device_map,
   3513     offload_folder=offload_folder,
   3514     offload_state_dict=offload_state_dict,
   3515     dtype=torch_dtype,
   3516     hf_quantizer=hf_quantizer,
   3517     keep_in_fp32_modules=keep_in_fp32_modules,
   3518 )
   3520 # make sure token embedding weights are still tied if needed
   3521 model.tie_weights()

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/transformers/modeling_utils.py:3903, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_modules)
   3901 if shard_file in disk_only_shard_files:
   3902     continue
-> 3903 state_dict = load_state_dict(shard_file)
   3905 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
   3906 # matching the weights in the model.
   3907 mismatched_keys += _find_mismatched_keys(
   3908     state_dict,
   3909     model_state_dict,
    (...)
   3913     ignore_mismatched_sizes,
   3914 )

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/transformers/modeling_utils.py:507, in load_state_dict(checkpoint_file)
    505 with safe_open(checkpoint_file, framework="pt") as f:
    506     metadata = f.metadata()
--> 507 if metadata.get("format") not in ["pt", "tf", "flax"]:
    508     raise OSError(
    509         f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
    510         "you save your model with the `save_pretrained` method."
    511     )
    512 return safe_load_file(checkpoint_file)

AttributeError: 'NoneType' object has no attribute 'get'
```
The error stems from the safetensors files missing the `{"format": "pt"}` metadata that `transformers` checks for when `AutoModelForCausalLM.from_pretrained()` loads them: `f.metadata()` returns `None`, so `metadata.get("format")` raises.
A quick workaround was to re-save each safetensors file with the metadata added, using the script below, and then upload the files to Hugging Face.
```python
from safetensors import safe_open
from safetensors.torch import save_file

safetensor_path = "lora_fused_model/model-00001-of-00004.safetensors"
# ...

# Split e.g. "model-00001-of-00004.safetensors" into stem and extension
fname, ext = safetensor_path.split("/")[-1].rsplit(".", 1)

# Copy every tensor into memory, then re-save with the metadata transformers expects
tensors = dict()
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

save_file(tensors, f"lora_fused_model/{fname}-with-format.{ext}", metadata={"format": "pt"})
```
However, it would be nice to be able to upload directly and have the model immediately usable by a wider audience.
Reading the source code led me to `mx.save_safetensors()`, which is why I'm filing the issue on this repo:
https://github.com/ml-explore/mlx-examples/blob/47dd6bd17f3cc7ef95672ea16e443e58ce5eb1bf/llms/mlx_lm/utils.py#L479
**Expected behavior**
Since there are many `transformers` users in the ecosystem, it would be beneficial to be able to seamlessly train and upload model weights to Hugging Face so that other users can load them through `transformers`.
**Desktop (please complete the following information):**
- OS Version: [e.g. MacOS 14.3]
- MacBook Pro M3 Max 128GB
- mlx==0.4.0
- mlx-lm==0.0.13
- transformers==4.38.1