Feature Request: Add DDP Communication Hooks

Motivation

I would like to request the addition of DDP communication hooks to the accelerate library. These hooks give users control over how gradients are communicated across workers during distributed training, for example by compressing them before the all-reduce, which reduces communication overhead. Frameworks such as PyTorch Lightning and Detectron2 already expose these hooks to cut communication costs and speed up training, and adding the same capability to accelerate would give its users similar benefits.

Feature Description

Introduce support for DDP communication hooks such as PowerSGD, FP16, and BF16 in the accelerate library. Users could then select one of these hooks and have it applied to optimize gradient communication during distributed training.
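To make the FP16 option concrete, here is a small stdlib-only simulation of the arithmetic behind PyTorch's `fp16_compress_hook`. It assumes plain Python lists stand in for gradient buckets and a Python loop stands in for the all-reduce, so no real communication happens, and the helper names are hypothetical. Each simulated worker casts its gradient to float16 and divides by the world size, the compressed values are summed, and the result is cast back. The BF16 hook follows the same compress/all-reduce/decompress pattern with bfloat16, while PowerSGD instead sends low-rank factors of each gradient, so neither is simulated here.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def simulated_fp16_allreduce(worker_grads: list[list[float]]) -> list[float]:
    """Average per-worker gradients the way an FP16 compression hook would.

    Hypothetical helper: plain lists stand in for gradient buckets and a
    Python loop is a rough stand-in for the all-reduce across processes.
    """
    world_size = len(worker_grads)
    # Compress: cast to fp16, then divide by the world size (result rounded to fp16)
    compressed = [
        [to_fp16(to_fp16(g) / world_size) for g in grads] for grads in worker_grads
    ]
    reduced = []
    for values in zip(*compressed):
        # Sum element-wise, rounding to fp16 after each addition
        total = 0.0
        for v in values:
            total = to_fp16(total + v)
        reduced.append(total)
    # Decompress: the fp16 results come back as ordinary Python floats
    return reduced


# Two simulated workers, three gradient elements each
grads = [[0.1, -2.5, 3.0], [0.3, -1.5, 1.0]]
exact = [sum(col) / len(grads) for col in zip(*grads)]
approx = simulated_fp16_allreduce(grads)

# Each element travels as 2 bytes instead of 4
print(struct.calcsize("<e"), struct.calcsize("<f"))  # 2 4
# The averaged result differs from the exact fp32 average only by fp16 rounding
print(max(abs(a - e) for a, e in zip(approx, exact)))
```

The trade-off it illustrates is half the bytes on the wire in exchange for a small rounding error; values that are exactly representable in float16 (here `-2.0` and `2.0`) come back unchanged.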

Example Code Snippet

Here is an example of how this feature could be used in accelerate:

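import torch
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Ask accelerate to compress gradients to FP16 before they are all-reduced
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.FP16,
)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
# Toy dataset of (input, target) pairs so the loop below has data to iterate over
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 10))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# prepare() wraps the model in DDP when running on multiple processes;
# that is where the requested hook would be registered
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for inputs, targets in data_loader:
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()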

For reference, here is how Detectron2 registers a DDP communication hook:

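from torch.nn.parallel import DistributedDataParallel
from detectron2.utils import comm

def create_ddp_model(model, *, fp16_compression=False, **kwargs):
    # Single process: no DDP wrapper is needed
    if comm.get_world_size() == 1:
        return model
    # Pin each process to its local GPU unless the caller chose devices
    if "device_ids" not in kwargs:
        kwargs["device_ids"] = [comm.get_local_rank()]
    ddp = DistributedDataParallel(model, **kwargs)
    if fp16_compression:
        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
        # Cast gradient buckets to FP16 for the all-reduce, then back to the original dtype
        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp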
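Building on the Detectron2 pattern above, here is a sketch of how all three hooks named in this request can be registered with plain PyTorch, which is roughly the step an accelerate option would need to perform after wrapping the model in DDP. The `attach_comm_hook` helper and its string names are hypothetical, and the PowerSGD arguments are illustrative rather than recommended settings.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook
from torch.nn.parallel import DistributedDataParallel


def attach_comm_hook(ddp_model: DistributedDataParallel, hook_name: str) -> None:
    """Hypothetical helper: register one of PyTorch's built-in DDP comm hooks.

    The accepted names ("fp16", "bf16", "power_sgd") are made up for this sketch.
    """
    if hook_name == "fp16":
        # Cast each gradient bucket to float16 for the all-reduce
        ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
    elif hook_name == "bf16":
        # Same pattern with bfloat16, which keeps float32's exponent range
        ddp_model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)
    elif hook_name == "power_sgd":
        # PowerSGD is stateful (error feedback, warm-started factors), so it needs a state object
        state = powerSGD_hook.PowerSGDState(
            process_group=None,  # None selects the default process group
            matrix_approximation_rank=1,  # illustrative value, not tuned
            start_powerSGD_iter=1000,  # run plain all-reduce before this iteration
        )
        ddp_model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)
    else:
        raise ValueError(f"unknown comm hook: {hook_name!r}")
```

In `create_ddp_model`, a call such as `attach_comm_hook(ddp, "bf16")` would take the place of the `if fp16_compression:` branch.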

Thank you for considering this feature request. This addition would help improve distributed training efficiency in the accelerate library.

Can you open this in the accelerate repo, please? :hugs: The forums are really just for answering Q&A that aren't direct bugs, and it's hard for us to keep track of feature requests here.


I have opened a PR for this. Thanks for your response :slight_smile:
