I am using PEFT to fine-tune a model, with Accelerate configured for bf16 mixed precision to reduce memory usage. When I call accelerator.clip_grad_norm_(model.parameters(), max_norm=1), I get
ValueError: Requires uniform dtype across all gradients but got {torch.bfloat16, torch.float32}
as shown in the traceback below:
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/contextlib.py", line 131, in __exit__
self.gen.throw(type, value, traceback)
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/accelerate/accelerator.py", line 886, in accumulate
yield
File "run_peft_accelerate_fsdp.py", line 294, in main
accelerator.clip_grad_norm_(model.parameters(), max_norm=torch.tensor(1, dtype=config.torch_dtype))
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/accelerate/accelerator.py", line 1812, in clip_grad_norm_
return model.clip_grad_norm_(max_norm, norm_type)
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1042, in clip_grad_norm_
local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(
File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1983, in _get_grad_norm
raise ValueError(
ValueError: Requires uniform dtype across all gradients but got {torch.bfloat16, torch.float32}
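For context, here is a minimal sketch of the call site. The model, optimizer, and dataloader below are placeholders standing in for my real PEFT model; in the actual run FSDP and bf16 are enabled through the accelerate config, so this only shows where clip_grad_norm_ is invoked:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Placeholder model/optimizer/data; the real script loads a PEFT-wrapped LLM
# and an FSDP-enabled Accelerator.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(8, 16)), batch_size=4)

accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for (x,) in dataloader:
    with accelerator.accumulate(model):
        loss = model(x).pow(2).mean()
        accelerator.backward(loss)
        # This is the call that raises the ValueError under FSDP + bf16.
        accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()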
As the traceback shows, the error is raised inside FSDP's _get_grad_norm() function, which refuses to proceed when it observes more than one dtype among the gradients:
params_with_grad = [param for param in params if param.grad is not None]
if len(params_with_grad) == 0:
    return torch.tensor(0.0)
grads = [param.grad for param in params_with_grad]
grad_dtypes = {grad.dtype for grad in grads}
if len(grad_dtypes) != 1:
    raise ValueError(
        f"Requires uniform dtype across all gradients but got {grad_dtypes}"
    )
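To see which parameters end up with which gradient dtype, I can run a quick diagnostic like the following right before the clipping call. It just iterates named_parameters (note that under FSDP the names refer to the flattened/sharded parameters), so it is a rough sketch rather than anything FSDP-specific:

from collections import defaultdict

def grad_dtype_report(model):
    # Group parameter names by the dtype of their gradient so any
    # float32 outliers among the bf16 gradients can be identified.
    by_dtype = defaultdict(list)
    for name, param in model.named_parameters():
        if param.grad is not None:
            by_dtype[param.grad.dtype].append(name)
    for dtype, names in by_dtype.items():
        print(f"{dtype}: {len(names)} gradient(s), e.g. {names[:3]}")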
This looks like a bug to me, since there does not seem to be anything a developer can do to avoid the mixed gradient dtypes when bf16 is used. How can I fix this?