Bug with model.generate if max_length or max_new_tokens are set, with accelerate deepspeed zero level 3

model.generate fails if max_length or max_new_tokens are set, with accelerate deepspeed zero level 3.

I use transformers.T5ModelForConditionalGeneration with google/t5-flan-* model, on a DGXA100 node (usually).

It seems that when a process finishes generating before the others (which almost always happens), the others get stuck waiting for it forever in a barrier. I was wondering if that was a known issue.

Everything works fine if each process has the same inputs, which makes sense as all processes finish at the same time. Somehow, everything also works fine if no value is passed for max_new_tokens or max_length and the default value of 20 is used.

accelerate               0.16.0
deepspeed                0.8.1
pytorch-triton           2.0.0+c8bfe3f548
torch                    1.12.1+cu113
torchaudio               0.12.1+cu113
torchtyping              0.1.4
torchvision              0.13.1+cu113
transformers             4.26.1

The end of the error message. You can see that rank 1 generates (and prints) its text, then rank 0 breaks at a barrier. When there are more processes, the same happens, one rank finishes first, then they all break at group._allgather_base(output_tensor, input_tensor), line 2136 of torch/distributed/distributed_c10d.py


[02/20/23 22:13:14] INFO     [1/2] __main__ - batch['input_ids'].shape = torch.Size([3, 238])                     test_accelerate.py:97
                    INFO     [1/2] __main__ - <accelerate.data_loader.DataLoaderShard object at 0x7f3c80115a00>   test_accelerate.py:99
                    INFO     [1/2] __main__ - {'input_ids': torch.Size([3, 238]), 'attention_mask':              test_accelerate.py:103
                             torch.Size([3, 238])}
                    INFO     [1/2] __main__ - dict_keys(['input_ids', 'attention_mask'])                         test_accelerate.py:105
                    INFO     [1/2] __main__ - max_new_tokens = 100                                               
[02/20/23 22:13:18] INFO     [1/2] __main__ - torch.Size([3, 31])                                                test_accelerate.py:113
                    INFO     [1/2] __main__ -   GENERATED TEXT:                                                              test_accelerate.py:114
                                     - amet labore voluptatem consectetur aliquam quiquia.</s>
                                     -  Sit adipisci neque tempora amet ipsum tempora aliquam.</s>
                                     -  etincidunt</s>
[E ProcessGroupGloo.cpp:2791] [Rank 0]: Rank 1 failed to pass monitoredBarrier in 1800000 ms
[E ProcessGroupGloo.cpp:136] [Rank 0]: Ranks 1 failed to pass monitoredBarrier in 1800000 ms
โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Traceback (most recent call last) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ
โ”‚ /home/mila/g/gagnonju/Marg-Li-CoT/with_trlx/test_accelerate.py:121 in <module>                   โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   118                                                                                            โ”‚
โ”‚   119                                                                                            โ”‚
โ”‚   120 if __name__ == "__main__":                                                                 โ”‚
โ”‚ โฑ 121 โ”‚   fire.Fire(main)                                                                        โ”‚
โ”‚   122                                                                                            โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/fire/core.py:141 in Fire                 โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   138 โ”‚   context.update(caller_globals)                                                         โ”‚
โ”‚   139 โ”‚   context.update(caller_locals)                                                          โ”‚
โ”‚   140                                                                                            โ”‚
โ”‚ โฑ 141   component_trace = _Fire(component, args, parsed_flag_args, context, name)                โ”‚
โ”‚   142                                                                                            โ”‚
โ”‚   143   if component_trace.HasError():                                                           โ”‚
โ”‚   144 โ”‚   _DisplayError(component_trace)                                                         โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/fire/core.py:475 in _Fire                โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   472 โ”‚     is_class = inspect.isclass(component)                                                โ”‚
โ”‚   473 โ”‚                                                                                          โ”‚
โ”‚   474 โ”‚     try:                                                                                 โ”‚
โ”‚ โฑ 475 โ”‚   โ”‚   component, remaining_args = _CallAndUpdateTrace(                                   โ”‚
โ”‚   476 โ”‚   โ”‚   โ”‚   component,                                                                     โ”‚
โ”‚   477 โ”‚   โ”‚   โ”‚   remaining_args,                                                                โ”‚
โ”‚   478 โ”‚   โ”‚   โ”‚   component_trace,                                                               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/fire/core.py:691 in _CallAndUpdateTrace  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   688 โ”‚   loop = asyncio.get_event_loop()                                                        โ”‚
โ”‚   689 โ”‚   component = loop.run_until_complete(fn(*varargs, **kwargs))                            โ”‚
โ”‚   690   else:                                                                                    โ”‚
โ”‚ โฑ 691 โ”‚   component = fn(*varargs, **kwargs)                                                     โ”‚
โ”‚   692                                                                                            โ”‚
โ”‚   693   if treatment == 'class':                                                                 โ”‚
โ”‚   694 โ”‚   action = trace.INSTANTIATED_CLASS                                                      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/Marg-Li-CoT/with_trlx/test_accelerate.py:109 in main                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   106 โ”‚   LOGGER.info(f"{max_new_tokens = }")                                                    โ”‚
โ”‚   107 โ”‚   a9r.wait_for_everyone()                                                                โ”‚
โ”‚   108 โ”‚   with torch.no_grad():                                                                  โ”‚
โ”‚ โฑ 109 โ”‚   โ”‚   output = model.generate(                                                           โ”‚
โ”‚   110 โ”‚   โ”‚   โ”‚   **batch,                                                                       โ”‚
โ”‚   111 โ”‚   โ”‚   โ”‚   max_length=max_new_tokens,                                                     โ”‚
โ”‚   112 โ”‚   โ”‚   )                                                                                  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27 in        โ”‚
โ”‚ decorate_context                                                                                 โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    24 โ”‚   โ”‚   @functools.wraps(func)                                                             โ”‚
โ”‚    25 โ”‚   โ”‚   def decorate_context(*args, **kwargs):                                             โ”‚
โ”‚    26 โ”‚   โ”‚   โ”‚   with self.clone():                                                             โ”‚
โ”‚ โฑ  27 โ”‚   โ”‚   โ”‚   โ”‚   return func(*args, **kwargs)                                               โ”‚
โ”‚    28 โ”‚   โ”‚   return cast(F, decorate_context)                                                   โ”‚
โ”‚    29 โ”‚                                                                                          โ”‚
โ”‚    30 โ”‚   def _wrap_generator(self, func):                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/transformers/generation/utils.py:1391 in โ”‚
โ”‚ generate                                                                                         โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   1388 โ”‚   โ”‚   โ”‚   โ”‚   )                                                                         โ”‚
โ”‚   1389 โ”‚   โ”‚   โ”‚                                                                                 โ”‚
โ”‚   1390 โ”‚   โ”‚   โ”‚   # 11. run greedy search                                                       โ”‚
โ”‚ โฑ 1391 โ”‚   โ”‚   โ”‚   return self.greedy_search(                                                    โ”‚
โ”‚   1392 โ”‚   โ”‚   โ”‚   โ”‚   input_ids,                                                                โ”‚
โ”‚   1393 โ”‚   โ”‚   โ”‚   โ”‚   logits_processor=logits_processor,                                        โ”‚
โ”‚   1394 โ”‚   โ”‚   โ”‚   โ”‚   stopping_criteria=stopping_criteria,                                      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/transformers/generation/utils.py:2179 in โ”‚
โ”‚ greedy_search                                                                                    โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   2176 โ”‚   โ”‚   โ”‚   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  โ”‚
โ”‚   2177 โ”‚   โ”‚   โ”‚                                                                                 โ”‚
โ”‚   2178 โ”‚   โ”‚   โ”‚   # forward pass to get next token                                              โ”‚
โ”‚ โฑ 2179 โ”‚   โ”‚   โ”‚   outputs = self(                                                               โ”‚
โ”‚   2180 โ”‚   โ”‚   โ”‚   โ”‚   **model_inputs,                                                           โ”‚
โ”‚   2181 โ”‚   โ”‚   โ”‚   โ”‚   return_dict=True,                                                         โ”‚
โ”‚   2182 โ”‚   โ”‚   โ”‚   โ”‚   output_attentions=output_attentions,                                      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/torch/nn/modules/module.py:1137 in       โ”‚
โ”‚ _call_impl                                                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   1134 โ”‚   โ”‚   โ”‚   full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()     โ”‚
โ”‚   1135 โ”‚   โ”‚   if _global_forward_pre_hooks or self._forward_pre_hooks:                          โ”‚
โ”‚   1136 โ”‚   โ”‚   โ”‚   for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.v  โ”‚
โ”‚ โฑ 1137 โ”‚   โ”‚   โ”‚   โ”‚   result = hook(self, input)                                                โ”‚
โ”‚   1138 โ”‚   โ”‚   โ”‚   โ”‚   if result is not None:                                                    โ”‚
โ”‚   1139 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   if not isinstance(result, tuple):                                     โ”‚
โ”‚   1140 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   result = (result,)                                                โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/utils/nvtx.py:9 in wrapped_fn  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    6 โ”‚   function call."""                                                                       โ”‚
โ”‚    7 โ”‚   def wrapped_fn(*args, **kwargs):                                                        โ”‚
โ”‚    8 โ”‚   โ”‚   get_accelerator().range_push(func.__qualname__)                                     โ”‚
โ”‚ โฑ  9 โ”‚   โ”‚   ret_val = func(*args, **kwargs)                                                     โ”‚
โ”‚   10 โ”‚   โ”‚   get_accelerator().range_pop()                                                       โ”‚
โ”‚   11 โ”‚   โ”‚   return ret_val                                                                      โ”‚
โ”‚   12                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload โ”‚
โ”‚ .py:348 in _pre_forward_module_hook                                                              โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   345 โ”‚   โ”‚                                                                                      โ”‚
โ”‚   346 โ”‚   โ”‚   @instrument_w_nvtx                                                                 โ”‚
โ”‚   347 โ”‚   โ”‚   def _pre_forward_module_hook(module, *args):                                       โ”‚
โ”‚ โฑ 348 โ”‚   โ”‚   โ”‚   self.pre_sub_module_forward_function(module)                                   โ”‚
โ”‚   349 โ”‚   โ”‚                                                                                      โ”‚
โ”‚   350 โ”‚   โ”‚   @instrument_w_nvtx                                                                 โ”‚
โ”‚   351 โ”‚   โ”‚   def _post_forward_module_hook(module, input, output):                              โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27 in        โ”‚
โ”‚ decorate_context                                                                                 โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    24 โ”‚   โ”‚   @functools.wraps(func)                                                             โ”‚
โ”‚    25 โ”‚   โ”‚   def decorate_context(*args, **kwargs):                                             โ”‚
โ”‚    26 โ”‚   โ”‚   โ”‚   with self.clone():                                                             โ”‚
โ”‚ โฑ  27 โ”‚   โ”‚   โ”‚   โ”‚   return func(*args, **kwargs)                                               โ”‚
โ”‚    28 โ”‚   โ”‚   return cast(F, decorate_context)                                                   โ”‚
โ”‚    29 โ”‚                                                                                          โ”‚
โ”‚    30 โ”‚   def _wrap_generator(self, func):                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload โ”‚
โ”‚ .py:478 in pre_sub_module_forward_function                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   475 โ”‚   โ”‚   param_coordinator.trace_prologue(sub_module)                                       โ”‚
โ”‚   476 โ”‚   โ”‚   if param_coordinator.is_record_trace():                                            โ”‚
โ”‚   477 โ”‚   โ”‚   โ”‚   param_coordinator.record_module(sub_module)                                    โ”‚
โ”‚ โฑ 478 โ”‚   โ”‚   param_coordinator.fetch_sub_module(sub_module)                                     โ”‚
โ”‚   479 โ”‚   โ”‚                                                                                      โ”‚
โ”‚   480 โ”‚   โ”‚   see_memory_usage(                                                                  โ”‚
โ”‚   481 โ”‚   โ”‚   โ”‚   f"Before sub module function {sub_module.__class__.__name__} after fetch",     โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/utils/nvtx.py:9 in wrapped_fn  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    6 โ”‚   function call."""                                                                       โ”‚
โ”‚    7 โ”‚   def wrapped_fn(*args, **kwargs):                                                        โ”‚
โ”‚    8 โ”‚   โ”‚   get_accelerator().range_push(func.__qualname__)                                     โ”‚
โ”‚ โฑ  9 โ”‚   โ”‚   ret_val = func(*args, **kwargs)                                                     โ”‚
โ”‚   10 โ”‚   โ”‚   get_accelerator().range_pop()                                                       โ”‚
โ”‚   11 โ”‚   โ”‚   return ret_val                                                                      โ”‚
โ”‚   12                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27 in        โ”‚
โ”‚ decorate_context                                                                                 โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    24 โ”‚   โ”‚   @functools.wraps(func)                                                             โ”‚
โ”‚    25 โ”‚   โ”‚   def decorate_context(*args, **kwargs):                                             โ”‚
โ”‚    26 โ”‚   โ”‚   โ”‚   with self.clone():                                                             โ”‚
โ”‚ โฑ  27 โ”‚   โ”‚   โ”‚   โ”‚   return func(*args, **kwargs)                                               โ”‚
โ”‚    28 โ”‚   โ”‚   return cast(F, decorate_context)                                                   โ”‚
โ”‚    29 โ”‚                                                                                          โ”‚
โ”‚    30 โ”‚   def _wrap_generator(self, func):                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/partitioned_param โ”‚
โ”‚ _coordinator.py:349 in fetch_sub_module                                                          โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   346 โ”‚   โ”‚   โ”‚   โ”‚                                                                              โ”‚
โ”‚   347 โ”‚   โ”‚   โ”‚   โ”‚   for param in params_to_prefetch:                                           โ”‚
โ”‚   348 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   debug_rank0(f"-prefetch: {param.ds_summary()}")                        โ”‚
โ”‚ โฑ 349 โ”‚   โ”‚   โ”‚   โ”‚   self.__all_gather_params(params_to_prefetch)                               โ”‚
โ”‚   350 โ”‚   โ”‚   โ”‚   โ”‚                                                                              โ”‚
โ”‚   351 โ”‚   โ”‚   โ”‚   โ”‚   if self.__prefetch_nvme:                                                   โ”‚
โ”‚   352 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   self.__prefetch_nvme_param_partitions()                                โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/utils/nvtx.py:9 in wrapped_fn  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    6 โ”‚   function call."""                                                                       โ”‚
โ”‚    7 โ”‚   def wrapped_fn(*args, **kwargs):                                                        โ”‚
โ”‚    8 โ”‚   โ”‚   get_accelerator().range_push(func.__qualname__)                                     โ”‚
โ”‚ โฑ  9 โ”‚   โ”‚   ret_val = func(*args, **kwargs)                                                     โ”‚
โ”‚   10 โ”‚   โ”‚   get_accelerator().range_pop()                                                       โ”‚
โ”‚   11 โ”‚   โ”‚   return ret_val                                                                      โ”‚
โ”‚   12                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/partitioned_param โ”‚
โ”‚ _coordinator.py:399 in __all_gather_params                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   396 โ”‚   โ”‚                                                                                      โ”‚
โ”‚   397 โ”‚   โ”‚   if partitioned_params:                                                             โ”‚
โ”‚   398 โ”‚   โ”‚   โ”‚   with get_accelerator().stream(self.__allgather_stream):                        โ”‚
โ”‚ โฑ 399 โ”‚   โ”‚   โ”‚   โ”‚   handle = partitioned_params[0].all_gather_coalesced(partitioned_params)    โ”‚
โ”‚   400 โ”‚   โ”‚   โ”‚                                                                                  โ”‚
โ”‚   401 โ”‚   โ”‚   โ”‚   for param in partitioned_params:                                               โ”‚
โ”‚   402 โ”‚   โ”‚   โ”‚   โ”‚   assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary()     โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/utils/nvtx.py:9 in wrapped_fn  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    6 โ”‚   function call."""                                                                       โ”‚
โ”‚    7 โ”‚   def wrapped_fn(*args, **kwargs):                                                        โ”‚
โ”‚    8 โ”‚   โ”‚   get_accelerator().range_push(func.__qualname__)                                     โ”‚
โ”‚ โฑ  9 โ”‚   โ”‚   ret_val = func(*args, **kwargs)                                                     โ”‚
โ”‚   10 โ”‚   โ”‚   get_accelerator().range_pop()                                                       โ”‚
โ”‚   11 โ”‚   โ”‚   return ret_val                                                                      โ”‚
โ”‚   12                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_paramet โ”‚
โ”‚ ers.py:876 in all_gather_coalesced                                                               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    873 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   for p in params                                                       โ”‚
โ”‚    874 โ”‚   โ”‚   โ”‚   โ”‚   ],                                                                        โ”‚
โ”‚    875 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚    out=partitions[self.rank])                   โ”‚
โ”‚ โฑ  876 โ”‚   โ”‚   โ”‚   โ”‚   handle = _dist_allgather_fn(partitions[self.rank],                        โ”‚
โ”‚    877 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   flat_tensor,                                  โ”‚
โ”‚    878 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   self.ds_process_group)                        โ”‚
โ”‚    879                                                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_paramet โ”‚
โ”‚ ers.py:43 in _dist_allgather_fn                                                                  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚     40                                                                                           โ”‚
โ”‚     41                                                                                           โ”‚
โ”‚     42 def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):          โ”‚
โ”‚ โฑ   43 โ”‚   return instrument_w_nvtx(dist.allgather_fn)(output_tensor,                            โ”‚
โ”‚     44 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   input_tensor,                             โ”‚
โ”‚     45 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   group=group,                              โ”‚
โ”‚     46 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   async_op=True)                            โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/utils/nvtx.py:9 in wrapped_fn  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    6 โ”‚   function call."""                                                                       โ”‚
โ”‚    7 โ”‚   def wrapped_fn(*args, **kwargs):                                                        โ”‚
โ”‚    8 โ”‚   โ”‚   get_accelerator().range_push(func.__qualname__)                                     โ”‚
โ”‚ โฑ  9 โ”‚   โ”‚   ret_val = func(*args, **kwargs)                                                     โ”‚
โ”‚   10 โ”‚   โ”‚   get_accelerator().range_pop()                                                       โ”‚
โ”‚   11 โ”‚   โ”‚   return ret_val                                                                      โ”‚
โ”‚   12                                                                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/comm/comm.py:340 in            โ”‚
โ”‚ allgather_fn                                                                                     โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   337 โ”‚   global has_warned_all_gather                                                           โ”‚
โ”‚   338 โ”‚   assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please    โ”‚
โ”‚   339 โ”‚   if cdb.has_allgather_base:                                                             โ”‚
โ”‚ โฑ 340 โ”‚   โ”‚   return all_gather_base(output_tensor,                                              โ”‚
โ”‚   341 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      input_tensor,                                               โ”‚
โ”‚   342 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      group=group,                                                โ”‚
โ”‚   343 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      async_op=async_op,                                          โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/comm/comm.py:127 in            โ”‚
โ”‚ log_wrapper                                                                                      โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   124 โ”‚   โ”‚   โ”‚   โ”‚   timers(log_name).start()                                                   โ”‚
โ”‚   125 โ”‚   โ”‚   # Return the op, then stop the op's timer                                          โ”‚
โ”‚   126 โ”‚   โ”‚   try:                                                                               โ”‚
โ”‚ โฑ 127 โ”‚   โ”‚   โ”‚   return func(*args, **kwargs)                                                   โ”‚
โ”‚   128 โ”‚   โ”‚   finally:                                                                           โ”‚
โ”‚   129 โ”‚   โ”‚   โ”‚   if comms_logger.enabled:                                                       โ”‚
โ”‚   130 โ”‚   โ”‚   โ”‚   โ”‚   # Need to make op blocking for accurate logging                            โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/comm/comm.py:318 in            โ”‚
โ”‚ all_gather_base                                                                                  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   315 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   log_name='all_gather_base',                                            โ”‚
โ”‚   316 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   debug=get_caller_func()):                                              โ”‚
โ”‚   317 โ”‚   global cdb                                                                             โ”‚
โ”‚ โฑ 318 โ”‚   return cdb.all_gather_base(output_tensor=output_tensor,                                โ”‚
โ”‚   319 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      input_tensor=tensor,                                        โ”‚
โ”‚   320 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      group=group,                                                โ”‚
โ”‚   321 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   โ”‚      async_op=async_op)                                          โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/deepspeed/comm/torch.py:83 in            โ”‚
โ”‚ all_gather_base                                                                                  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    80 โ”‚                                                                                          โ”‚
โ”‚    81 โ”‚   def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):    โ”‚
โ”‚    82 โ”‚   โ”‚   if self.has_allgather_base:                                                        โ”‚
โ”‚ โฑ  83 โ”‚   โ”‚   โ”‚   return torch.distributed.distributed_c10d._all_gather_base(                    โ”‚
โ”‚    84 โ”‚   โ”‚   โ”‚   โ”‚   output_tensor=output_tensor,                                               โ”‚
โ”‚    85 โ”‚   โ”‚   โ”‚   โ”‚   input_tensor=input_tensor,                                                 โ”‚
โ”‚    86 โ”‚   โ”‚   โ”‚   โ”‚   group=group,                                                               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py:21 โ”‚
โ”‚ 36 in _all_gather_base                                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   2133 โ”‚   โ”‚   default_pg = _get_default_group()                                                 โ”‚
โ”‚   2134 โ”‚   โ”‚   work = default_pg._allgather_base(output_tensor, input_tensor)                    โ”‚
โ”‚   2135 โ”‚   else:                                                                                 โ”‚
โ”‚ โฑ 2136 โ”‚   โ”‚   work = group._allgather_base(output_tensor, input_tensor)                         โ”‚
โ”‚   2137 โ”‚                                                                                         โ”‚
โ”‚   2138 โ”‚   if async_op:                                                                          โ”‚
โ”‚   2139 โ”‚   โ”‚   return work                                                                       โ”‚
โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ
RuntimeError: [Rank 0]: Ranks 1 failed to pass monitoredBarrier in 1800000 ms
[22:13:22] ERROR    failed (exitcode: 1) local_rank: 0 (pid: 622774) of binary: /home/mila/g/gagnonju/.main/bin/python

cc @smangrul

add synced_gpus=True to model.generate() params. Check the DeepSpeed section of Launch Configuration tab in the interactive example code explorer tool for more details: Learning how to incorporate :hugs: Accelerate features quickly! (huggingface.co)