Hi everybody!
I am having the same problem, which is quite frustrating because I am super excited about OPT-IML: Meta's alternative to OpenAI's ChatGPT, with a strong focus on responsible compute and a genuine commitment to open source.
The weights are out, and I can't wait for the models to be fully compatible with HuggingFace pipelines.
I did not manage to get it up and running for inference on multiple GPUs. OPT focuses on efficiency, so there is a lot of cutting-edge model parallelism going on, and I suspect compatibility with HuggingFace modules, and especially with Accelerate, is currently broken. Let me share some of my research and blockers:
I am working on an AWS g5.48xlarge (8x NVIDIA A10G) with CUDA 11.7.
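Just to rule out a driver or visibility problem first, this is how I confirm that all eight GPUs are visible to PyTorch (nothing here is OPT-specific):

```python
import torch

# sanity check: all 8 A10Gs should show up before attempting any sharding
print(torch.cuda.is_available())             # True
print(torch.cuda.device_count())             # 8
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))  # NVIDIA A10G
```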
I did some tests with the smaller models, which should easily fit in GPU memory. I can successfully load the model distributed across multiple GPUs with the following code:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
MODEL_NAME = "facebook/opt-2.7b"
# load model with device_map="auto"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto").cuda()
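As a quick sanity check, the device map that Accelerate attaches to the model (`hf_device_map`, added by `from_pretrained(..., device_map="auto")`) shows that the layers are indeed spread across the GPUs:

```python
# inspect how Accelerate sharded the model across the 8 GPUs
for module_name, device in model.hf_device_map.items():
    print(f"{module_name} -> {device}")

# e.g. model.decoder.layers.0 -> 0, ..., model.decoder.layers.31 -> 7
```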
But it fails at inference:
# the fast tokenizer currently does not work correctly
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
prompt = "Hello, I am conscious and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
generated_ids = model.generate(input_ids)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 5
3 prompt = "Hello, I am conscious and"
4 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
----> 5 generated_ids = model.generate(input_ids)
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/huggingface/.venv/lib/python3.8/site-packages/transformers/generation/utils.py:1518, in GenerationMixin.generate(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, penalty_alpha, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, suppress_tokens, begin_suppress_tokens, forced_decoder_ids, **model_kwargs)
1513 raise ValueError(
1514 f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
1515 )
1517 # 10. run greedy search
-> 1518 return self.greedy_search(
1519 input_ids,
1520 logits_processor=logits_processor,
1521 stopping_criteria=stopping_criteria,
1522 pad_token_id=pad_token_id,
1523 eos_token_id=eos_token_id,
1524 output_scores=output_scores,
1525 return_dict_in_generate=return_dict_in_generate,
1526 synced_gpus=synced_gpus,
1527 **model_kwargs,
1528 )
1530 elif is_contrastive_search_gen_mode:
1532 if num_return_sequences > 1:
File ~/huggingface/.venv/lib/python3.8/site-packages/transformers/generation/utils.py:2285, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
2282 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2284 # forward pass to get next token
-> 2285 outputs = self(
2286 **model_inputs,
2287 return_dict=True,
2288 output_attentions=output_attentions,
2289 output_hidden_states=output_hidden_states,
2290 )
2292 if synced_gpus and this_peer_finished:
2293 continue # don't waste resources running the code we don't need
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/huggingface/.venv/lib/python3.8/site-packages/accelerate/hooks.py:156, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
154 output = old_forward(*args, **kwargs)
155 else:
--> 156 output = old_forward(*args, **kwargs)
157 return module._hf_hook.post_forward(module, output)
File ~/huggingface/.venv/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:934, in OPTForCausalLM.forward(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
931 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
933 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 934 outputs = self.model.decoder(
935 input_ids=input_ids,
936 attention_mask=attention_mask,
937 head_mask=head_mask,
938 past_key_values=past_key_values,
939 inputs_embeds=inputs_embeds,
940 use_cache=use_cache,
941 output_attentions=output_attentions,
942 output_hidden_states=output_hidden_states,
943 return_dict=return_dict,
944 )
946 logits = self.lm_head(outputs[0]).contiguous()
948 loss = None
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/huggingface/.venv/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:698, in OPTDecoder.forward(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
689 layer_outputs = torch.utils.checkpoint.checkpoint(
690 create_custom_forward(decoder_layer),
691 hidden_states,
(...)
694 None,
695 )
696 else:
--> 698 layer_outputs = decoder_layer(
699 hidden_states,
700 attention_mask=attention_mask,
701 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
702 past_key_value=past_key_value,
703 output_attentions=output_attentions,
704 use_cache=use_cache,
705 )
707 hidden_states = layer_outputs[0]
709 if use_cache:
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/huggingface/.venv/lib/python3.8/site-packages/accelerate/hooks.py:156, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
154 output = old_forward(*args, **kwargs)
155 else:
--> 156 output = old_forward(*args, **kwargs)
157 return module._hf_hook.post_forward(module, output)
File ~/huggingface/.venv/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:324, in OPTDecoderLayer.forward(self, hidden_states, attention_mask, layer_head_mask, output_attentions, use_cache, past_key_value)
322 # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
323 if self.do_layer_norm_before:
--> 324 hidden_states = self.self_attn_layer_norm(hidden_states)
326 # Self Attention
327 hidden_states, self_attn_weights, present_key_value = self.self_attn(
328 hidden_states=hidden_states,
329 past_key_value=past_key_value,
(...)
332 output_attentions=output_attentions,
333 )
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/huggingface/.venv/lib/python3.8/site-packages/accelerate/hooks.py:156, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
154 output = old_forward(*args, **kwargs)
155 else:
--> 156 output = old_forward(*args, **kwargs)
157 return module._hf_hook.post_forward(module, output)
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/modules/normalization.py:190, in LayerNorm.forward(self, input)
189 def forward(self, input: Tensor) -> Tensor:
--> 190 return F.layer_norm(
191 input, self.normalized_shape, self.weight, self.bias, self.eps)
File ~/huggingface/.venv/lib/python3.8/site-packages/torch/nn/functional.py:2515, in layer_norm(input, normalized_shape, weight, bias, eps)
2511 if has_torch_function_variadic(input, weight, bias):
2512 return handle_torch_function(
2513 layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
2514 )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__native_layer_norm)
@sgugger it would be awesome if you could shed some light on this!
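For reference, here is the variant I would have expected to work from my reading of the Accelerate docs: no explicit `.cuda()` on the sharded model, and the inputs placed on the first GPU, since Accelerate's hooks are supposed to route activations between devices. I may well be misusing the API, so treat this as a sketch rather than a confirmed fix:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "facebook/opt-2.7b"

# let Accelerate place the shards and do NOT move the model afterwards
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)

# the fast tokenizer currently does not work correctly
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

prompt = "Hello, I am conscious and"
# inputs only need to be on the device that hosts the embedding layer (cuda:0 here)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")

generated_ids = model.generate(input_ids, max_new_tokens=30)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```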
Other pointers
Hardware Issues
From the official metaseq repo:
Where can I run this?
Right now only on Azure, as it requires the 80GB A100s. To enable it on other locations, we need to either try CPU offloading, or we need to use MP 16. FSDP should not be used because some workers will only be used for parameter hosting, and will not actually perform computations.
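On the CPU-offloading route: Accelerate can already offload weights that don't fit on the GPUs to CPU RAM or disk via `max_memory` / `offload_folder`. Below is a minimal sketch of what that could look like; the checkpoint id is just a placeholder and I have not verified this against the 175B weights, so take it as a starting point for discussion:

```python
from transformers import AutoModelForCausalLM
import torch

# sketch only: cap each A10G at ~20GiB and spill the remainder to CPU RAM / disk
max_memory = {i: "20GiB" for i in range(8)}
max_memory["cpu"] = "400GiB"

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-iml-30b",      # placeholder id; adjust to the checkpoint you actually have
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory=max_memory,
    offload_folder="offload",    # weights that fit nowhere else go here
)
```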
Are we stuck with 80GB A100s?
Using OPT-175B with Alpa
Alpa promises an easier setup thanks to more flexible parallelism strategies on older generations of GPUs (40GB A100, V100, T4, M60, etc.).
But it only ships pip wheels for CUDA (cuDNN) 11.1 (8.0.5), 11.2 (8.1.0), and 11.3 (8.2.0), which is somewhat outdated.
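To compare against that list, this is how I check what my own environment reports (on the g5.48xlarge it comes back with CUDA 11.7, hence the mismatch):

```python
import torch

# CUDA / cuDNN versions PyTorch was built against
print(torch.version.cuda)               # e.g. '11.7'
print(torch.backends.cudnn.version())   # e.g. 8500 -> cuDNN 8.5.0
```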