Can someone explain how to fix a problem I am facing with safetensors saving? It seems something is done to base_model in the original model that I need to replicate. The error can be reproduced with the following:
from transformers import AutoModelForAudioClassification
from safetensors.torch import save_file

# Placeholder labels, just so the snippet runs on its own
id2label = {0: "a", 1: "b"}
label2id = {"a": 0, "b": 1}
num_labels = len(id2label)

model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=num_labels, label2id=label2id, id2label=id2label
)
print('Save 1')
save_file(model.state_dict(), 'temp')
print('Save 1 Complete')
# Re-assigning base_model to itself is what triggers the failure on the next save
model.base_model = model.base_model
print('Save 2')
save_file(model.state_dict(), 'temp')
print('Save 2 Complete')
This outputs:
Save 1
Save 1 Complete
Save 2
Traceback (most recent call last):
.....
RuntimeError:
Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'wav2vec2.masked_spec_embed', 'base_model.masked_spec_embed'},
.....
'base_model.encoder.layers.11.final_layer_norm.bias'}].
A potential way to correctly save your model is to use `save_model`.
More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
My only guess is that there is something being flagged with base_model to tell safetensors to ignore it, but when I reset the variable that flag is getting deleted.
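To make the sharing concrete: after the re-assignment, the duplicated keys the error lists really do point at the same storage, and the `save_model` helper the error suggests is documented to de-duplicate shared tensors before writing. A minimal sketch, continuing from the snippet above:

from safetensors.torch import save_model
# After the re-assignment, state_dict() contains a second set of keys under the
# 'base_model.' prefix that share storage with the 'wav2vec2.' entries:
sd = model.state_dict()
print(sd['wav2vec2.masked_spec_embed'].data_ptr() == sd['base_model.masked_spec_embed'].data_ptr())  # True
# save_model drops the shared duplicates before serializing, so it succeeds where save_file errors:
save_model(model, 'temp')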
This is the easiest way to get around it, but it’s just a workaround.
Quoted GitHub issue (opened 12:55AM - 24 Apr 24 UTC, closed 08:51AM - 24 Apr 24 UTC, labeled "solved"):
### Reminder
- [X] I have read the README and searched the existing issues.
### Reproduction
WANDB_DISABLED=1 NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 deepspeed --num_gpus 2 --master_port=9527 /workspace/projects/LLaMA-Factory/src/train_bash.py \
--stage rm \
--do_train \
--deepspeed xxxxxxxx/ds_z3_offload_config.json \
--model_name_or_path xxxxxxx/chatglm3-6b \
--adapter_name_or_path /xxx/chatglm_exp_sft_lora_llamafactory \
--create_new_adapter \
--dataset comparison_gpt4_zh \
--dataset_dir xxx/data \
--template chatglm3 \
--finetuning_type lora \
--lora_target query_key_value \
--output_dir xxx/chatglm_exp_rm_lora_llamafactory \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 4 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 10 \
--eval_steps 20 \
--evaluation_strategy steps \
--learning_rate 1e-5 \
--num_train_epochs 2.0 \
--max_samples 5000 \
--val_size 0.1 \
--plot_loss \
--fp16
Training runs fine, but the following error is raised when saving a checkpoint:
[INFO|trainer.py:3305] 2024-04-23 16:56:46,579 >> Saving model checkpoint to /workspace/models/huggingface/chatglm32k_rm_sft_lora_llamafactory/checkpoint-10
[INFO|trainer.py:3319] 2024-04-23 16:56:46,587 >> Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Traceback (most recent call last):
File "/workspace/projects/LLaMA-Factory/src/train_bash.py", line 14, in <module>
main()
File "/workspace/projects/LLaMA-Factory/src/train_bash.py", line 5, in main
run_exp()
File "/workspace/projects/LLaMA-Factory/src/llmtuner/train/tuner.py", line 35, in run_exp
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
File "/workspace/projects/LLaMA-Factory/src/llmtuner/train/rm/workflow.py", line 50, in run_rm
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2278, in _inner_training_loop
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2673, in _maybe_log_save_evaluate
self._save_checkpoint(model, trial, metrics=metrics)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2752, in _save_checkpoint
self.save_model(output_dir, _internal_call=True)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3239, in save_model
self._save(output_dir, state_dict=state_dict)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3321, in _save
safetensors.torch.save_file(
File "/opt/conda/lib/python3.10/site-packages/safetensors/torch.py", line 284, in save_file
serialize_file(_flatten(tensors), filename, metadata=metadata)
File "/opt/conda/lib/python3.10/site-packages/safetensors/torch.py", line 480, in _flatten
raise RuntimeError(
RuntimeError:
Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'pretrained_model.base_model.model.lm_head.weight', 'pretrained_model.base_model.model.transformer.output_layer.weight'}].
A potential way to correctly save your model is to use `save_model`.
More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
### Expected behavior
Able to save the RM checkpoints and complete training successfully.
### System Info
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- `transformers` version: 4.40.0
- Platform: Linux-5.15.0-101-generic-x86_64-with-glibc2.31
- Python version: 3.10.11
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
### Others
None
--save_safetensors False
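That flag maps onto the save_safetensors field of TrainingArguments, so it can also be set in code (a sketch, assuming a standard Trainer/HfArgumentParser setup):

from transformers import TrainingArguments
# save_safetensors=False makes Trainer fall back to torch.save when writing
# checkpoints instead of safetensors.
args = TrainingArguments(
    output_dir="out",  # placeholder
    save_safetensors=False,
)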
Thanks for the reply. That just causes Hugging Face to use torch.save instead of safetensors.save_file, correct? That has worked for me in other cases, but this model seems to have parametrized modules and gives the following error with torch.save:
RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
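(If I'm reading the linked page right, the distinction it draws is between pickling the module object itself and saving its state dict; a quick illustration with the same model:)

import torch
from transformers import AutoModelForAudioClassification
model = AutoModelForAudioClassification.from_pretrained("facebook/wav2vec2-base")
# torch.save(model, "model_full.pt")        # pickling the parametrized module raises the error above
torch.save(model.state_dict(), "model.pt")  # going through state_dict() is the supported path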
Do you know of any methods other than this workaround?