import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from accelerate import Accelerator
accelerator = Accelerator()
device = accelerator.device
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")
model = model.to(device)
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
input_ids = input_ids.to(device)
print("in: ", input_ids)
outputs = model.generate(input_ids)
#print(tokenizer.decode(outputs[0]))
and I got this error
Traceback (most recent call last):
File "demo_acc.py", line 22, in <module>
outputs = model.generate(input_ids)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1325, in generate
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 639, in _prepare_encoder_decoder_kwargs_for_generation
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 988, in forward
inputs_embeds = self.embed_tokens(input_ids)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 162, in forward
return F.embedding(
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py", line 2212, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: /home/tonghengwen/code/pytorch/xla/torch_xla/csrc/aten_xla_bridge.cpp:74 : Check failed: xtensor
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::bridge::GetXlaTensor(at::Tensor const&)
torch_xla::XLANativeFunctions::index_select(at::Tensor const&, long, at::Tensor const&)
(omitted stack trace)
Py_BytesMain
__libc_start_main
*** End stack trace ***
Input tensor is not an XLA tensor: torch.LongTensor
After a few debugging, I noticed at the beginning the input_ids tensor has device attribute
in: tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58,
1]], device='xla:0')
But at this layer
File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 988, in forward
inputs_embeds = self.embed_tokens(input_ids)
when calling T5Stack.forward the device attribute is gone
input ids tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58,
1]])
I tried with accelerator and xla_model and both end in the same error.
Any help is much appreciated!
Thanks