I’m working through the “Pipelines for inference” tutorial on a multi-GPU “g4dn.12xlarge” instance. It runs fine when I set device=0, but when I switch to device_map="auto" I get an “Expected all tensors to be on the same device” error.
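For reference, this is the single-GPU variant that works for me (a minimal sketch; device=0 simply pins the whole model to cuda:0, same model and sample files as below):

from transformers import pipeline

# works: the whole model stays on a single GPU (cuda:0)
generator = pipeline(model="openai/whisper-large", device=0)
generator(
    [
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
    ]
)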
Here’s the code I am running:
from transformers import pipeline
generator = pipeline(model="openai/whisper-large", device_map="auto")
generator(
    [
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
    ]
)
And here’s the error stack:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:2!
File <command-2351028218258033>:1
----> 1 generator(
2 [
3 "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
4 "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
5 ]
6 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/automatic_speech_recognition.py:378, in AutomaticSpeechRecognitionPipeline.__call__(self, inputs, **kwargs)
331 def __call__(
332 self,
333 inputs: Union[np.ndarray, bytes, str],
334 **kwargs,
335 ):
336 """
337 Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
338 documentation for more information.
(...)
376 `"".join(chunk["text"] for chunk in output["chunks"])`.
377 """
--> 378 return super().__call__(inputs, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:1065, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
1061 if can_use_iterator:
1062 final_iterator = self.get_iterator(
1063 inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1064 )
-> 1065 outputs = [output for output in final_iterator]
1066 return outputs
1067 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:1065, in <listcomp>(.0)
1061 if can_use_iterator:
1062 final_iterator = self.get_iterator(
1063 inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1064 )
-> 1065 outputs = [output for output in final_iterator]
1066 return outputs
1067 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py:124, in PipelineIterator.__next__(self)
121 return self.loader_batch_item()
123 # We're out of items within a batch
--> 124 item = next(self.iterator)
125 processed = self.infer(item, **self.params)
126 # We now have a batch of "inferred things".
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/pt_utils.py:266, in PipelinePackIterator.__next__(self)
263 return accumulator
265 while not is_last:
--> 266 processed = self.infer(next(self.iterator), **self.params)
267 if self.loader_batch_size is not None:
268 if isinstance(processed, torch.Tensor):
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/base.py:992, in Pipeline.forward(self, model_inputs, **forward_params)
990 with inference_context():
991 model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
--> 992 model_outputs = self._forward(model_inputs, **forward_params)
993 model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
994 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/pipelines/automatic_speech_recognition.py:562, in AutomaticSpeechRecognitionPipeline._forward(self, model_inputs, return_timestamps, generate_kwargs)
560 elif self.type == "seq2seq_whisper":
561 stride = model_inputs.pop("stride", None)
--> 562 tokens = self.model.generate(
563 input_features=model_inputs.pop("input_features"),
564 logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
565 **generate_kwargs,
566 )
567 out = {"tokens": tokens}
568 if stride is not None:
File /databricks/python/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/generation/utils.py:1391, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
1385 raise ValueError(
1386 f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
1387 " greedy search."
1388 )
1390 # 11. run greedy search
-> 1391 return self.greedy_search(
1392 input_ids,
1393 logits_processor=logits_processor,
1394 stopping_criteria=stopping_criteria,
1395 pad_token_id=generation_config.pad_token_id,
1396 eos_token_id=generation_config.eos_token_id,
1397 output_scores=generation_config.output_scores,
1398 return_dict_in_generate=generation_config.return_dict_in_generate,
1399 synced_gpus=synced_gpus,
1400 **model_kwargs,
1401 )
1403 elif is_contrastive_search_gen_mode:
1404 if generation_config.num_return_sequences > 1:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/generation/utils.py:2179, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
2176 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2178 # forward pass to get next token
-> 2179 outputs = self(
2180 **model_inputs,
2181 return_dict=True,
2182 output_attentions=output_attentions,
2183 output_hidden_states=output_hidden_states,
2184 )
2186 if synced_gpus and this_peer_finished:
2187 continue # don't waste resources running the code we don't need
File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
156 output = old_forward(*args, **kwargs)
157 else:
--> 158 output = old_forward(*args, **kwargs)
159 return module._hf_hook.post_forward(module, output)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1196, in WhisperForConditionalGeneration.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1191 if decoder_input_ids is None and decoder_inputs_embeds is None:
1192 decoder_input_ids = shift_tokens_right(
1193 labels, self.config.pad_token_id, self.config.decoder_start_token_id
1194 )
-> 1196 outputs = self.model(
1197 input_features,
1198 decoder_input_ids=decoder_input_ids,
1199 encoder_outputs=encoder_outputs,
1200 decoder_attention_mask=decoder_attention_mask,
1201 head_mask=head_mask,
1202 decoder_head_mask=decoder_head_mask,
1203 cross_attn_head_mask=cross_attn_head_mask,
1204 past_key_values=past_key_values,
1205 decoder_inputs_embeds=decoder_inputs_embeds,
1206 use_cache=use_cache,
1207 output_attentions=output_attentions,
1208 output_hidden_states=output_hidden_states,
1209 return_dict=return_dict,
1210 )
1211 lm_logits = self.proj_out(outputs[0])
1213 loss = None
File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:1065, in WhisperModel.forward(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
1058 encoder_outputs = BaseModelOutput(
1059 last_hidden_state=encoder_outputs[0],
1060 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1061 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1062 )
1064 # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
-> 1065 decoder_outputs = self.decoder(
1066 input_ids=decoder_input_ids,
1067 attention_mask=decoder_attention_mask,
1068 encoder_hidden_states=encoder_outputs[0],
1069 head_mask=decoder_head_mask,
1070 cross_attn_head_mask=cross_attn_head_mask,
1071 past_key_values=past_key_values,
1072 inputs_embeds=decoder_inputs_embeds,
1073 use_cache=use_cache,
1074 output_attentions=output_attentions,
1075 output_hidden_states=output_hidden_states,
1076 return_dict=return_dict,
1077 )
1079 if not return_dict:
1080 return decoder_outputs + encoder_outputs
File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:926, in WhisperDecoder.forward(self, input_ids, attention_mask, encoder_hidden_states, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
914 layer_outputs = torch.utils.checkpoint.checkpoint(
915 create_custom_forward(decoder_layer),
916 hidden_states,
(...)
922 None, # past_key_value
923 )
924 else:
--> 926 layer_outputs = decoder_layer(
927 hidden_states,
928 attention_mask=attention_mask,
929 encoder_hidden_states=encoder_hidden_states,
930 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
931 cross_attn_layer_head_mask=(
932 cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
933 ),
934 past_key_value=past_key_value,
935 output_attentions=output_attentions,
936 use_cache=use_cache,
937 )
938 hidden_states = layer_outputs[0]
940 if use_cache:
File /databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-73443bb4-cc2e-4537-9100-b657a55cc01a/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:426, in WhisperDecoderLayer.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, use_cache)
417 hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
418 hidden_states=hidden_states,
419 key_value_states=encoder_hidden_states,
(...)
423 output_attentions=output_attentions,
424 )
425 hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
--> 426 hidden_states = residual + hidden_states
428 # add cross-attn to positions 3,4 of present_key_value tuple
429 present_key_value = present_key_value + cross_attn_present_key_value