Error in clip_grad_norm_ for bf16 via PEFT

I am using PEFT to fine-tune a model, and I use Accelerate with bf16 to reduce memory usage. When I call accelerator.clip_grad_norm_(model.parameters(), max_norm=1), I get a ValueError: Requires uniform dtype across all gradients but got {torch.bfloat16, torch.float32}, as shown below:

  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/contextlib.py", line 131, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/accelerate/accelerator.py", line 886, in accumulate
    yield
  File "run_peft_accelerate_fsdp.py", line 294, in main
    accelerator.clip_grad_norm_(model.parameters(), max_norm=torch.tensor(1, dtype=config.torch_dtype))
  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/accelerate/accelerator.py", line 1812, in clip_grad_norm_
    return model.clip_grad_norm_(max_norm, norm_type)
  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1042, in clip_grad_norm_
    local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(
  File "/home/user23/miniconda3/envs/newllmpeft2/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1983, in _get_grad_norm
    raise ValueError(
ValueError: Requires uniform dtype across all gradients but got {torch.bfloat16, torch.float32}

As shown, the error actually happens in FSDP's _get_grad_norm() function, which raises when it observes more than one dtype among the gradients:

    params_with_grad = [param for param in params if param.grad is not None]
    if len(params_with_grad) == 0:
        return torch.tensor(0.0)
    grads = [param.grad for param in params_with_grad]
    grad_dtypes = {grad.dtype for grad in grads}
    if len(grad_dtypes) != 1:
        raise ValueError(
            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
        )
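To see which parameters contribute which gradient dtype, a small check mirroring the snippet above can be run after a backward pass (grad_dtypes_by_param below is just a hypothetical helper for illustration, not part of Accelerate or FSDP):

    from collections import defaultdict

    def grad_dtypes_by_param(model):
        # Group parameter names by the dtype of their gradient --
        # the same set of dtypes that FSDP's _get_grad_norm() inspects.
        dtypes = defaultdict(list)
        for name, param in model.named_parameters():
            if param.grad is not None:
                dtypes[param.grad.dtype].append(name)
        return dict(dtypes)

    # After accelerator.backward(loss):
    # for dtype, names in grad_dtypes_by_param(model).items():
    #     print(dtype, len(names), names[:3])

This makes it easy to see whether the float32 gradients come from the PEFT adapter parameters or from somewhere else.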

This looks like a bug, given that there is no obvious way for a developer to work around this check when bf16 is being used. So I was wondering how I can fix this.
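One workaround I have been considering (not sure it is the right approach) is to cast every trainable parameter to a single dtype before wrapping the model with FSDP / calling accelerator.prepare(), so that all gradients end up with the same dtype. A minimal sketch, assuming the mismatch comes from the trainable PEFT parameters being kept in float32 while the rest of the model is in bfloat16 (unify_trainable_dtype is a hypothetical helper, not an Accelerate API):

    import torch

    def unify_trainable_dtype(model, dtype=torch.bfloat16):
        # Cast every parameter that requires a gradient to `dtype`;
        # gradients are created with the same dtype as their parameter,
        # so afterwards the dtype set checked by FSDP should be uniform.
        for param in model.parameters():
            if param.requires_grad and param.dtype != dtype:
                param.data = param.data.to(dtype)
        return model

    # model = unify_trainable_dtype(model)  # before accelerator.prepare(model, ...)

Whether casting the adapter weights down to bf16 is acceptable numerically is a separate question, so I would still like to know what the intended fix is.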

Replied here: Error in clip_grad_norm_ for bf16 via PEFT · Issue #1628 · huggingface/accelerate (github.com)
