Error making predictions with an LMM (LLaVA) model on multiple GPUs

Hi everyone,

I’m having difficulty running inference with a deep learning model distributed across multiple GPUs, in this case LLaVA. I’ve set the device_map parameter so that the model’s layers are automatically spread across the available GPUs, but I get an error when calling the model’s generate method. Code details below.

import torch
from PIL import Image
from transformers import LlavaProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# Load model and processor
checkpoint = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path = checkpoint,
    return_dict_in_generate = True,
    device_map = 'balanced'
)
processor = LlavaProcessor.from_pretrained(
    pretrained_model_name_or_path = checkpoint,
)

# Inputs Stage
image = Image.open("test/00000001.jpg")
prompt = "USER: <image>\n{} ASSISTANT:"

inputs = processor(
    text = prompt.format("Describe this image for me"),
    images = image,
    padding = True,
    return_tensors = "pt"
).to("cuda:0")

# Prediction Stage
_ = model.eval()
with torch.no_grad():
    outputs = model.generate(
        input_ids       = inputs['input_ids'],
        attention_mask  = inputs['attention_mask'],
        pixel_values    = inputs['pixel_values'],
        max_new_tokens  = 128,
        pad_token_id    = processor.tokenizer.eos_token_id,
    )

# Decode Stage
pred = processor.batch_decode(
    outputs.sequences, 
    skip_special_tokens=True, 
    clean_up_tokenization_spaces=False
)[0]

print(pred)

I’ve verified how each layer of the model is allocated across the different GPUs (in my case, I’m using 2):

In [5]: model.hf_device_map
Out[5]: 
{'vision_tower': 0,
 'multi_modal_projector': 0,
 'language_model.model.embed_tokens': 0,
 'language_model.model.layers.0': 0,
 'language_model.model.layers.1': 0,
 'language_model.model.layers.2': 0,
 'language_model.model.layers.3': 0,
 'language_model.model.layers.4': 0,
 'language_model.model.layers.5': 0,
 'language_model.model.layers.6': 0,
 'language_model.model.layers.7': 0,
 'language_model.model.layers.8': 0,
 'language_model.model.layers.9': 0,
 'language_model.model.layers.10': 0,
 'language_model.model.layers.11': 0,
 'language_model.model.layers.12': 0,
 'language_model.model.layers.13': 0,
 'language_model.model.layers.14': 0,
 'language_model.model.layers.15': 1,
 'language_model.model.layers.16': 1,
 'language_model.model.layers.17': 1,
 'language_model.model.layers.18': 1,
 'language_model.model.layers.19': 1,
 'language_model.model.layers.20': 1,
 'language_model.model.layers.21': 1,
 'language_model.model.layers.22': 1,
 'language_model.model.layers.23': 1,
 'language_model.model.layers.24': 1,
 'language_model.model.layers.25': 1,
 'language_model.model.layers.26': 1,
 'language_model.model.layers.27': 1,
 'language_model.model.layers.28': 1,
 'language_model.model.layers.29': 1,
 'language_model.model.layers.30': 1,
 'language_model.model.layers.31': 1,
 'language_model.model.norm': 1,
 'language_model.lm_head': 1}

Given this mapping, I made sure the input tensors are on the same device as the model’s first layers ('cuda:0'); a short snippet of how I do this is included after the error message below. I’ve also tried more generic assignments like to('cuda'), as well as different options for the device_map parameter such as 'balanced', 'sequential', and 'auto', but the problem persists, yielding the following error:

...
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)
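For reference, this is roughly how I pin the inputs to the device of the model’s first layers; the hf_device_map key is the one from the mapping shown above (a minimal sketch, not the exact script I run):

# Device holding the vision tower, projector and embedding layer (0 in my case)
first_device = model.hf_device_map["language_model.model.embed_tokens"]

# Move every tensor produced by the processor to that device
inputs = processor(
    text = prompt.format("Describe this image for me"),
    images = image,
    padding = True,
    return_tensors = "pt"
).to(f"cuda:{first_device}")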

It’s important to mention that my goal is to run the model in full precision (fp32). If I quantize the model to 4, 8, or 16 bits, it runs without issues on a single GPU.
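For context, the quantized loads that do work for me on a single GPU look roughly like this (a sketch; 4-bit via bitsandbytes and 16-bit via torch_dtype, the exact settings aren’t the point of this post):

import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig

checkpoint = "llava-hf/llava-1.5-7b-hf"

# 4-bit quantized load (fits and runs on a single GPU)
model_4bit = LlavaForConditionalGeneration.from_pretrained(
    checkpoint,
    quantization_config = BitsAndBytesConfig(load_in_4bit=True),
    device_map = "auto",
)

# 16-bit (half precision) load, also fine on a single GPU
model_fp16 = LlavaForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype = torch.float16,
    device_map = "auto",
)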

I would greatly appreciate any suggestions or shared experiences from those who have encountered a similar problem.

The complete error traceback is below.

Thanks and regards!

RuntimeError                              Traceback (most recent call last)
Cell In[5], line 36
     34 _ = model.eval()
     35 with torch.no_grad():
---> 36     outputs = model.generate(
     37         input_ids       = inputs['input_ids'],
     38         attention_mask  = inputs['attention_mask'],
     39         pixel_values    = inputs['pixel_values'],
     40         max_new_tokens  = 128,
     41         pad_token_id    = processor.tokenizer.eos_token_id,
     42     )
     44 pred = processor.batch_decode(
     45     outputs.sequences, 
     46     skip_special_tokens=True, 
     47     clean_up_tokenization_spaces=False
     48 )[0]
     50 print(pred)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/generation/utils.py:1718, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1701     return self.assisted_decoding(
   1702         input_ids,
   1703         assistant_model=assistant_model,
   (...)
   1714         **model_kwargs,
   1715     )
   1716 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1717     # 11. run greedy search
-> 1718     return self.greedy_search(
   1719         input_ids,
   1720         logits_processor=logits_processor,
   1721         stopping_criteria=stopping_criteria,
   1722         pad_token_id=generation_config.pad_token_id,
   1723         eos_token_id=generation_config.eos_token_id,
   1724         output_scores=generation_config.output_scores,
   1725         return_dict_in_generate=generation_config.return_dict_in_generate,
   1726         synced_gpus=synced_gpus,
   1727         streamer=streamer,
   1728         **model_kwargs,
   1729     )
   1731 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1732     if not model_kwargs["use_cache"]:

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/generation/utils.py:2579, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2576 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2578 # forward pass to get next token
-> 2579 outputs = self(
   2580     **model_inputs,
   2581     return_dict=True,
   2582     output_attentions=output_attentions,
   2583     output_hidden_states=output_hidden_states,
   2584 )
   2586 if synced_gpus and this_peer_finished:
   2587     continue  # don't waste resources running the code we don't need

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/models/llava/modeling_llava.py:433, in LlavaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    430             attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
    431             position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
--> 433 outputs = self.language_model(
    434     attention_mask=attention_mask,
    435     position_ids=position_ids,
    436     past_key_values=past_key_values,
    437     inputs_embeds=inputs_embeds,
    438     use_cache=use_cache,
    439     output_attentions=output_attentions,
    440     output_hidden_states=output_hidden_states,
    441     return_dict=return_dict,
    442 )
    444 logits = outputs[0]
    446 loss = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:1174, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1171 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1173 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1174 outputs = self.model(
   1175     input_ids=input_ids,
   1176     attention_mask=attention_mask,
   1177     position_ids=position_ids,
   1178     past_key_values=past_key_values,
   1179     inputs_embeds=inputs_embeds,
   1180     use_cache=use_cache,
   1181     output_attentions=output_attentions,
   1182     output_hidden_states=output_hidden_states,
   1183     return_dict=return_dict,
   1184 )
   1186 hidden_states = outputs[0]
   1187 if self.config.pretraining_tp > 1:

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:1061, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1051     layer_outputs = self._gradient_checkpointing_func(
   1052         decoder_layer.__call__,
   1053         hidden_states,
   (...)
   1058         use_cache,
   1059     )
   1060 else:
-> 1061     layer_outputs = decoder_layer(
   1062         hidden_states,
   1063         attention_mask=attention_mask,
   1064         position_ids=position_ids,
   1065         past_key_value=past_key_values,
   1066         output_attentions=output_attentions,
   1067         use_cache=use_cache,
   1068     )
   1070 hidden_states = layer_outputs[0]
   1072 if use_cache:

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:789, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    786 hidden_states = self.input_layernorm(hidden_states)
    788 # Self Attention
--> 789 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    790     hidden_states=hidden_states,
    791     attention_mask=attention_mask,
    792     position_ids=position_ids,
    793     past_key_value=past_key_value,
    794     output_attentions=output_attentions,
    795     use_cache=use_cache,
    796     **kwargs,
    797 )
    798 hidden_states = residual + hidden_states
    800 # Fully Connected

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/micromamba/envs/ml/lib/python3.9/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py:408, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    406 if past_key_value is not None:
    407     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
--> 408     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    410 key_states = repeat_kv(key_states, self.num_key_value_groups)
    411 value_states = repeat_kv(value_states, self.num_key_value_groups)

File ~/micromamba/envs/ml/lib/python3.9/site-packages/transformers/cache_utils.py:127, in DynamicCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
    125     self.value_cache.append(value_states)
    126 else:
--> 127     self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
    128     self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
    130 return self.key_cache[layer_idx], self.value_cache[layer_idx]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)