opened 04:09PM - 30 May 24 UTC
closed 12:25PM - 25 Jun 24 UTC
### System Info
Python 3.11.5
torch 2.3.0
transformers 4.41.1
accelerate 0.30.1
```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06 Driver Version: 545.23.06 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-80GB Off | 00000000:D6:00.0 Off | 0 |
| N/A 34C P0 61W / 400W | 4MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM4-80GB Off | 00000000:DA:00.0 Off | 0 |
| N/A 36C P0 60W / 400W | 4MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
```
### Reproduction
I want to fine-tune Llama 3 70B with HF TRL. I am trying to combine Accelerate, bitsandbytes quantization, mixed precision, and FSDP on two GPUs with 80 GB each.
Running this code
```
if rank == 0:
    hf_model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-70B-Instruct",
        load_in_8bit=True,
        device_map="auto",
    )
```
via this command:
```
> accelerate launch --config_file ./accelerate_fsdp_config.yaml train.py
```
On a Slurm-managed cluster, inside an sbatch script with:
```
#SBATCH --nodes=1
#SBATCH --gpus-per-node=2
```
and this FSDP config:
```
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Results in:
> [rank1]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
The same error is raised on rank 0, and also when I set `load_in_8bit=True`.
Full traceback:
```
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/artyom_karpov/rl4steg/train.py", line 345, in <module>
[rank0]: train(context)
[rank0]: File "/data/artyom_karpov/rl4steg/train.py", line 83, in train
[rank0]: hf_model = LlamaForCausalLM.from_pretrained(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/artyom_karpov/rl4steg/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3754, in from_pretrained
[rank0]: ) = cls._load_pretrained_model(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/artyom_karpov/rl4steg/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4214, in _load_pretrained_model
[rank0]: new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/artyom_karpov/rl4steg/.venv/lib/python3.11/site-packages/transformers/modeling_utils.py", line 896, in _load_state_dict_into_meta_model
[rank0]: value = type(value)(value.data.to("cpu"), **value.__dict__)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/artyom_karpov/rl4steg/.venv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 297, in __new__
[rank0]: return torch.Tensor._make_subclass(cls, data, requires_grad)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
```
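For context, the message at the bottom of the traceback is the generic PyTorch error for asking an integer-dtype tensor to track gradients. The sketch below reproduces only that message with plain `torch`, using no transformers or bitsandbytes code. It assumes that the quantized weight reaching `_make_subclass` has an integer dtype while `requires_grad` is true. That is a guess about the mechanism, not a confirmed diagnosis of this issue.

```python
import torch

# An int8 tensor, standing in for a quantized weight (assumption for illustration).
int8_data = torch.zeros(4, dtype=torch.int8)

# Wrapping integer data in a Parameter with the default requires_grad=True
# goes through torch.Tensor._make_subclass and is rejected.
try:
    torch.nn.Parameter(int8_data)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)

# The same data is accepted once gradient tracking is disabled.
frozen = torch.nn.Parameter(int8_data, requires_grad=False)
print(frozen.dtype, frozen.requires_grad)
```

If that assumption holds, the question is why `requires_grad` ends up true for the quantized parameter when it is moved to CPU in `_load_state_dict_into_meta_model`.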
### Expected behavior
I expect the model to be loaded.