Device_map="auto" with error: Expected all tensors to be on the same device

I’m trying to follow the Pipelines for inference tutorial on a multi-GPU instance (“g4dn.12xlarge”). Everything works when I set device_id=0, but when I try device_map="auto", I get an “Expected all tensors to be on the same device” error.

Here’s the code I am running:

from transformers import pipeline
generator = pipeline(model="openai/whisper-large", device_map="auto")
generator(
    [
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
    ]
)

And here’s the error stack:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:2!


File <command-2351028218258033>:1
----> 1 generator(
      2     [
      3         "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
      4         "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
      5     ]
      6 )

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/automatic_speech_recognition.py:378, in AutomaticSpeechRecognitionPipeline.__call__(self, inputs, **kwargs)
    331 def __call__(
    332     self,
    333     inputs: Union[np.ndarray, bytes, str],
    334     **kwargs,
    335 ):
    336     """
    337     Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
    338     documentation for more information.
   (...)
    376                     `"".join(chunk["text"] for chunk in output["chunks"])`.
    377     """
--> 378     return super().__call__(inputs, **kwargs)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:1065, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1061 if can_use_iterator:
   1062     final_iterator = self.get_iterator(
   1063         inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
   1064     )
-> 1065     outputs = [output for output in final_iterator]
   1066     return outputs
   1067 else:

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:1065, in <listcomp>(.0)
   1061 if can_use_iterator:
   1062     final_iterator = self.get_iterator(
   1063         inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
   1064     )
-> 1065     outputs = [output for output in final_iterator]
   1066     return outputs
   1067 else:

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py:124, in PipelineIterator.__next__(self)
    121     return self.loader_batch_item()
    123 # We're out of items within a batch
--> 124 item = next(self.iterator)
    125 processed = self.infer(item, **self.params)
    126 # We now have a batch of "inferred things".

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py:266, in PipelinePackIterator.__next__(self)
    263             return accumulator
    265 while not is_last:
--> 266     processed = self.infer(next(self.iterator), **self.params)
    267     if self.loader_batch_size is not None:
    268         if isinstance(processed, torch.Tensor):

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:992, in Pipeline.forward(self, model_inputs, **forward_params)
    990     with inference_context():
    991         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
--> 992         model_outputs = self._forward(model_inputs, **forward_params)
    993         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
    994 else:

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/automatic_speech_recognition.py:562, in AutomaticSpeechRecognitionPipeline._forward(self, model_inputs, return_timestamps, generate_kwargs)
    560 elif self.type == "seq2seq_whisper":
    561     stride = model_inputs.pop("stride", None)
--> 562     tokens = self.model.generate(
    563         input_features=model_inputs.pop("input_features"),
    564         logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
    565         **generate_kwargs,
    566     )
    567     out = {"tokens": tokens}
    568     if stride is not None:

File /databricks/python/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/generation/utils.py:1391, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
   1385         raise ValueError(
   1386             f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
   1387             " greedy search."
   1388         )
   1390     # 11. run greedy search
-> 1391     return self.greedy_search(
   1392         input_ids,
   1393         logits_processor=logits_processor,
   1394         stopping_criteria=stopping_criteria,
   1395         pad_token_id=generation_config.pad_token_id,
   1396         eos_token_id=generation_config.eos_token_id,
   1397         output_scores=generation_config.output_scores,
   1398         return_dict_in_generate=generation_config.return_dict_in_generate,
   1399         synced_gpus=synced_gpus,
   1400         **model_kwargs,
   1401     )
   1403 elif is_contrastive_search_gen_mode:
   1404     if generation_config.num_return_sequences > 1:

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/generation/utils.py:2179, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   2176 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2178 # forward pass to get next token
-> 2179 outputs = self(
   2180     **model_inputs,
   2181     return_dict=True,
   2182     output_attentions=output_attentions,
   2183     output_hidden_states=output_hidden_states,
   2184 )
   2186 if synced_gpus and this_peer_finished:
   2187     continue  # don't waste resources running the code we don't need

File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    156         output = old_forward(*args, **kwargs)
    157 else:
--> 158     output = old_forward(*args, **kwargs)
    159 return module._hf_hook.post_forward(module, output)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1196, in WhisperForConditionalGeneration.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1191     if decoder_input_ids is None and decoder_inputs_embeds is None:
   1192         decoder_input_ids = shift_tokens_right(
   1193             labels, self.config.pad_token_id, self.config.decoder_start_token_id
   1194         )
-> 1196 outputs = self.model(
   1197     input_features,
   1198     decoder_input_ids=decoder_input_ids,
   1199     encoder_outputs=encoder_outputs,
   1200     decoder_attention_mask=decoder_attention_mask,
   1201     head_mask=head_mask,
   1202     decoder_head_mask=decoder_head_mask,
   1203     cross_attn_head_mask=cross_attn_head_mask,
   1204     past_key_values=past_key_values,
   1205     decoder_inputs_embeds=decoder_inputs_embeds,
   1206     use_cache=use_cache,
   1207     output_attentions=output_attentions,
   1208     output_hidden_states=output_hidden_states,
   1209     return_dict=return_dict,
   1210 )
   1211 lm_logits = self.proj_out(outputs[0])
   1213 loss = None

File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1065, in WhisperModel.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1058     encoder_outputs = BaseModelOutput(
   1059         last_hidden_state=encoder_outputs[0],
   1060         hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
   1061         attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
   1062     )
   1064 # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
-> 1065 decoder_outputs = self.decoder(
   1066     input_ids=decoder_input_ids,
   1067     attention_mask=decoder_attention_mask,
   1068     encoder_hidden_states=encoder_outputs[0],
   1069     head_mask=decoder_head_mask,
   1070     cross_attn_head_mask=cross_attn_head_mask,
   1071     past_key_values=past_key_values,
   1072     inputs_embeds=decoder_inputs_embeds,
   1073     use_cache=use_cache,
   1074     output_attentions=output_attentions,
   1075     output_hidden_states=output_hidden_states,
   1076     return_dict=return_dict,
   1077 )
   1079 if not return_dict:
   1080     return decoder_outputs + encoder_outputs

File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:926, in WhisperDecoder.forward(self, input_ids, attention_mask, encoder_hidden_states, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    914     layer_outputs = torch.utils.checkpoint.checkpoint(
    915         create_custom_forward(decoder_layer),
    916         hidden_states,
   (...)
    922         None,  # past_key_value
    923     )
    924 else:
--> 926     layer_outputs = decoder_layer(
    927         hidden_states,
    928         attention_mask=attention_mask,
    929         encoder_hidden_states=encoder_hidden_states,
    930         layer_head_mask=(head_mask[idx] if head_mask is not None else None),
    931         cross_attn_layer_head_mask=(
    932             cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
    933         ),
    934         past_key_value=past_key_value,
    935         output_attentions=output_attentions,
    936         use_cache=use_cache,
    937     )
    938 hidden_states = layer_outputs[0]
    940 if use_cache:

File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:426, in WhisperDecoderLayer.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, use_cache)
    417 hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
    418     hidden_states=hidden_states,
    419     key_value_states=encoder_hidden_states,
   (...)
    423     output_attentions=output_attentions,
    424 )
    425 hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
--> 426 hidden_states = residual + hidden_states
    428 # add cross-attn to positions 3,4 of present_key_value tuple
    429 present_key_value = present_key_value + cross_attn_present_key_value

Maybe @ybelkada will have an idea.

Any update on this? I’m running into the same error. I’m using GPT-J with device_map="auto". Inference works with model.generate, but calling model(input_ids) directly runs into this error.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B",
                device_map="auto",
                load_in_8bit=False)

text = "this is a test.."
input_ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    target_ids = input_ids.clone()
    full_outputs = model(input_ids, labels=target_ids)

I’m guessing it’s because of your labels. This is a bug we will fix soon, but you can work around it by putting your target_ids on the same device as the last layer of your model.
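Something like this, as a rough untested sketch building on your snippet. It relies on the hf_device_map dictionary that accelerate attaches when a model is loaded with device_map="auto", and it assumes the last entry of that dictionary corresponds to the output layer, so check both assumptions for your model:

import torch

# hf_device_map maps submodule names to devices, in placement order,
# e.g. {'transformer.wte': 0, ..., 'lm_head': 3}
print(model.hf_device_map)

first_device = list(model.hf_device_map.values())[0]   # device holding the embeddings
last_device = list(model.hf_device_map.values())[-1]   # device holding the output layer / loss

input_ids = tokenizer("this is a test..", return_tensors="pt").input_ids.to(first_device)

with torch.no_grad():
    target_ids = input_ids.clone().to(last_device)  # labels on the last layer's device
    full_outputs = model(input_ids, labels=target_ids)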


Is there any update? I’m hitting the same problem running openai/whisper-large-v2 with device_map='auto'.

Any update? I get the same error when running the PEFT Whisper model.

I encountered the same issue this week and resolved it by sending the inputs to ‘cuda:0’ explicitly.
Here is the example mentioned earlier with that change:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B",
                device_map="auto",
                load_in_8bit=False)

text = "this is a test.."
# input_ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
input_ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda:0")

with torch.no_grad():
    target_ids = input_ids.clone()
    full_outputs = model(input_ids, labels=target_ids)
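
If that works, it is presumably because device_map="auto" places the embedding layer on the first GPU, so the inputs have to start on cuda:0. A small variation that avoids hardcoding the device (get_input_embeddings() is the standard transformers accessor for the input embedding module, but verify it returns what you expect for your architecture):

# Send the inputs to whatever device holds the embedding layer,
# rather than hardcoding "cuda:0".
embedding_device = model.get_input_embeddings().weight.device
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(embedding_device)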