How do I do inference using the GPT models on TPUs?

I have tried the following, but it did not work:

!pip uninstall -y torch torchvision torchtext torchaudio

!pip install transformers \
  cloud-tpu-client==0.10 \
  datasets \
  torchvision \
  torchaudio \
  librosa \
  jiwer \
  parsivar \
  num2fawords \
  torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

!pip install datasets transformers[sentencepiece]

import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.optim import Adam
import torch.nn.functional as F

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

device = xm.xla_device()

from transformers import pipeline, AutoTokenizer, GPT2LMHeadModel
tokenizer = AutoTokenizer.from_pretrained('flax-community/gpt2-medium-persian')
model = GPT2LMHeadModel.from_pretrained('flax-community/gpt2-medium-persian')
model = model.to(device)
generator = pipeline('text-generation', model, tokenizer=tokenizer, config={'max_length':100})

But it fails with this error:

out = generator('در مورد پدر هری پاتر شک هایی وجود دارد.')
Setting `pad_token_id` to `eos_token_id`:5 for open-end generation.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-f6da4d1d2479> in <module>()
----> 1 get_ipython().magic("timeit out = generator('در مورد پدر هری پاتر شک هایی وجود دارد.')")

19 frames
<decorator-gen-52> in timeit(self, line, cell)

<magic-timeit> in inner(_it, _timer)

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2041         # remove once script supports set_grad_enabled
   2042         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2043     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2044 
   2045 

RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:69 : Check failed: xtensor 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	torch_xla::AtenXlaType::index_select(at::Tensor const&, long, at::Tensor const&)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, at::Tensor const&), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, at::Tensor const&> >, at::Tensor (at::Tensor const&, long, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&)
	at::Tensor::index_select(long, at::Tensor const&) const
	at::native::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	torch_xla::AtenXlaType::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, long, bool, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, long, bool, bool> >, at::Tensor (at::Tensor const&, at::Tensor const&, long, bool, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long, bool, bool)
	at::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
	
	... (long run of Python interpreter frames elided) ...
*** End stack trace ***
Input tensor is not an XLA tensor: torch.LongTensor

The pipeline function does not support TPUs; you will have to manually pass your batch through the model (after placing it on the XLA device) and then post-process the outputs yourself.
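
Something along these lines should work as a minimal sketch. It reuses your model and tokenizer from above; max_length=100 mirrors the value you tried to pass to the pipeline, and eos_token_id is used as pad_token_id to match the warning in your log:

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, GPT2LMHeadModel

device = xm.xla_device()

tokenizer = AutoTokenizer.from_pretrained('flax-community/gpt2-medium-persian')
model = GPT2LMHeadModel.from_pretrained('flax-community/gpt2-medium-persian').to(device)
model.eval()

prompt = 'در مورد پدر هری پاتر شک هایی وجود دارد.'

# Tokenize on CPU, then move the input ids to the XLA device.
# This is the step the pipeline does not do for you, hence the
# "Input tensor is not an XLA tensor" error above.
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_length=100,
        pad_token_id=tokenizer.eos_token_id,
    )

# Post-process: move the generated ids back to CPU and decode them.
print(tokenizer.decode(output_ids[0].cpu(), skip_special_tokens=True))

Note that model.generate triggers an XLA recompilation for every new sequence length, so the first calls will be slow; for serious TPU inference you would want to pad inputs to fixed shapes.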

Are there any examples of doing this in the docs or somewhere?

You should check the task summary page in the Transformers documentation; it shows how to run each task, including text generation, directly with a model and tokenizer instead of a pipeline.