How to run T5 with Accelerator/XLA

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from accelerate import Accelerator
accelerator = Accelerator()

device = accelerator.device


tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")
model = model.to(device)


input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
input_ids = input_ids.to(device)

print("in: ", input_ids)

outputs = model.generate(input_ids)

#print(tokenizer.decode(outputs[0]))

and I got this error

Traceback (most recent call last):
  File "demo_acc.py", line 22, in <module>
    outputs = model.generate(input_ids)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1325, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 639, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 988, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py", line 2212, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: /home/tonghengwen/code/pytorch/xla/torch_xla/csrc/aten_xla_bridge.cpp:74 : Check failed: xtensor 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	torch_xla::XLANativeFunctions::index_select(at::Tensor const&, long, at::Tensor const&)
	(omitted stack trace)
	Py_BytesMain
	__libc_start_main
	
*** End stack trace ***
Input tensor is not an XLA tensor: torch.LongTensor


After a few debugging, I noticed at the beginning the input_ids tensor has device attribute

in:  tensor([[13959,  1566,    12,  2968,    10,   571,   625,    33,    25,    58,
             1]], device='xla:0')

But at this layer

  File "/home/tonghengwen/anaconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 988, in forward
    inputs_embeds = self.embed_tokens(input_ids)

when calling T5Stack.forward the device attribute is gone

input ids tensor([[13959,  1566,    12,  2968,    10,   571,   625,    33,    25,    58,
             1]])

I tried with accelerator and xla_model and both end in the same error.

Any help is much appreciated!

Thanks