I have tried the following, but it did not work:
!pip uninstall -y torch
!pip uninstall -y torchvision
!pip uninstall -y torchtext
!pip uninstall -y torchaudio
!pip install transformers \
cloud-tpu-client==0.10 \
datasets \
torchvision \
torchaudio \
librosa \
jiwer \
parsivar \
num2fawords \
torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install datasets transformers[sentencepiece]
import torch_xla
import torch_xla.core.xla_model as xm
import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.optim import Adam
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
device = xm.xla_device()
from transformers import pipeline, AutoTokenizer, GPT2LMHeadModel
tokenizer = AutoTokenizer.from_pretrained('flax-community/gpt2-medium-persian')
model = GPT2LMHeadModel.from_pretrained('flax-community/gpt2-medium-persian')
model = model.to(device)
generator = pipeline('text-generation', model, tokenizer=tokenizer, config={'max_length':100})
But it errors:
out = generator('در مورد پدر هری پاتر شک هایی وجود دارد.')
Setting `pad_token_id` to `eos_token_id`:5 for open-end generation.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-f6da4d1d2479> in <module>()
----> 1 get_ipython().magic("timeit out = generator('در مورد پدر هری پاتر شک هایی وجود دارد.')")
19 frames
<decorator-gen-52> in timeit(self, line, cell)
<magic-timeit> in inner(_it, _timer)
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2041 # remove once script supports set_grad_enabled
2042 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2043 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2044
2045
RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:69 : Check failed: xtensor
*** Begin stack trace ***
tensorflow::CurrentStackTrace()
torch_xla::bridge::GetXlaTensor(at::Tensor const&)
torch_xla::AtenXlaType::index_select(at::Tensor const&, long, at::Tensor const&)
c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, at::Tensor const&), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, at::Tensor const&> >, at::Tensor (at::Tensor const&, long, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&)
at::Tensor::index_select(long, at::Tensor const&) const
at::native::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
torch_xla::AtenXlaType::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, long, bool, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, long, bool, bool> >, at::Tensor (at::Tensor const&, at::Tensor const&, long, bool, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, long, bool, bool)
at::embedding(at::Tensor const&, at::Tensor const&, long, bool, bool)
_PyMethodDef_RawFastCallKeywords
_PyCFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
_PyObject_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
_PyObject_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallDict
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
_PyObject_FastCallKeywords
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallDict
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCode
_PyMethodDef_RawFastCallKeywords
_PyCFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyObject_Call_Prepend
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallDict
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallDict
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
_PyFunction_FastCallKeywords
_PyEval_EvalFrameDefault
*** End stack trace ***
Input tensor is not an XLA tensor: torch.LongTensor