Using XLA fast text generation with Pegasus models

Hi All

I have a simple script that paraphrases English text using tuner007/pegasus_paraphrase, a fine-tuned model based on PEGASUS.

Here's the sample code:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ph_model_name = "tuner007/pegasus_paraphrase"

ph_tokenizer = AutoTokenizer.from_pretrained(ph_model_name)
ph_model = AutoModelForSeq2SeqLM.from_pretrained(ph_model_name)

def get_response(input_text, num_return_sequences=2, num_beams=7, max_length=512, temperature=0.7):
  # Tokenize the input as PyTorch tensors
  batch = ph_tokenizer([input_text], truncation=True,
                       padding="longest", max_length=max_length,
                       return_tensors="pt")
  # Beam search combined with top-k / top-p sampling
  translated = ph_model.generate(**batch, max_length=max_length,
                                 num_beams=num_beams, num_return_sequences=num_return_sequences,
                                 temperature=temperature, do_sample=True, top_k=90, top_p=0.95)
  # Convert token IDs back to strings
  tgt_text = ph_tokenizer.batch_decode(translated, skip_special_tokens=True)
  return tgt_text

get_response('The world has been inching toward fully autonomous cars for years')

The output looks like this (it can vary between runs, since do_sample=True):

['The world has been working on fully self-driving cars.',
 'The world has been working on fully self-driving cars.']

Now I'm following this blog post to use the new XLA support to speed up inference.

I'm trying to run it with the following code:

from transformers import AutoModelForSeq2SeqLM , AutoTokenizer
import tensorflow as tf

from tensorflow.python.ops.numpy_ops import np_config

ph_model_name = "tuner007/pegasus_paraphrase"

# torch_device = "cuda:0"
ph_tokenizer = AutoTokenizer.from_pretrained(ph_model_name)
ph_model = AutoModelForSeq2SeqLM.from_pretrained(ph_model_name)

tokenization_kwargs = {"max_length": 512, "padding": "longest", "truncation": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 7, "max_length": 512,
                     "num_return_sequences":2, "temperature":0.7,
                     "do_sample": True, "top_k": 90, "top_p": 0.95,
                     "no_repeat_ngram_size": 2, "early_stopping": True}

# generate a paraphrased text
xla_generate = tf.function(ph_model.generate, jit_compile=True)

input_prompt = 'the world has been inching toward fully autonomous cars for years .'
tokenized_inputs = ph_tokenizer([input_prompt], **tokenization_kwargs)

generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
decoded_text = ph_tokenizer.decode(generated_text[0], skip_special_tokens=True)


It fails with the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-8-a25ee9d320ec> in <module>
      1 input_prompt = 'the world has been inching toward fully autonomous cars for years .'
      2 tokenized_inputs = ph_tokenizer([input_prompt], **tokenization_kwargs)
----> 3 generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
      4 decoded_text = ph_tokenizer.decode(generated_text[0], skip_special_tokens=True)
      5 print(decoded_text)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

TypeError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/torch/autograd/", line 847, in decorate_context  *
        return func(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/transformers/", line 1182, in generate  *
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
    File "/usr/local/lib/python3.7/dist-packages/transformers/", line 525, in _prepare_encoder_decoder_kwargs_for_generation  *
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/", line 1130, in _call_impl  *
        return forward_call(*input, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/transformers/models/pegasus/", line 753, in forward  *
        input_shape = input_ids.size()

    TypeError: 'numpy.int64' object is not callable

Does anyone know what causes this error?

Has anyone found another way to use XLA with this model?

Thank you


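You're using AutoModelForSeq2SeqLM, which loads the PyTorch model (the traceback shows torch code being called inside tf.function), but XLA generation with tf.function needs the TensorFlow class, TFAutoModelForSeq2SeqLM. The repo you're loading only has a PyTorch checkpoint, so the proper way to load it is: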

from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase", from_pt=True)
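For completeness, here is a minimal end-to-end sketch of what the fixed version might look like, loosely following the blog post's approach. It assumes a transformers version that still ships the TensorFlow models, plus torch installed (needed for from_pt=True). The padding and generation values are illustrative, paraphrase() is a hypothetical helper name, and no_repeat_ngram_size and the sampling options are deliberately left out to keep to plain beam search; whether those options compile under XLA may depend on the transformers version.

```python
# Sketch: XLA-compiled generation with the TensorFlow version of the model.
# Assumes transformers with TensorFlow support, plus torch (required by from_pt=True).
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

model_name = "tuner007/pegasus_paraphrase"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The repo only ships a PyTorch checkpoint, so convert the weights while loading
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, from_pt=True)

# Compile generate with XLA; the first call for each new input shape / argument set is slow
xla_generate = tf.function(model.generate, jit_compile=True)

# Pad to a multiple of 8 so inputs of similar length share one compiled graph
tokenization_kwargs = {"padding": True, "pad_to_multiple_of": 8, "return_tensors": "tf"}
# Plain beam search; sampling options and no_repeat_ngram_size are left out on purpose
generation_kwargs = {"num_beams": 7, "num_return_sequences": 2, "max_length": 64}


def paraphrase(text):
    """Hypothetical helper: tokenize, run the XLA-compiled generate, decode."""
    inputs = tokenizer([text], **tokenization_kwargs)
    outputs = xla_generate(**inputs, **generation_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


paraphrases = paraphrase("The world has been inching toward fully autonomous cars for years.")
print(paraphrases)
```

Calling paraphrase() a second time with a sentence of similar length should reuse the compiled graph and return much faster than the first call.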

Thanks for your reply.

I switched to TFAutoModelForSeq2SeqLM, but now it gives this error:

AttributeError: 'Tensor' object has no attribute 'numpy'

Can you help?
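The message itself comes from TensorFlow rather than from Pegasus: inside tf.function, tensors are symbolic graph tensors, and only eager tensors have a .numpy() method. So some code path reached by generate is presumably converting a tensor to NumPy, which only works eagerly. One plausible suspect, not verified here against this transformers version, is a logits processor implemented with .numpy(), such as the one behind no_repeat_ngram_size. A minimal TensorFlow-only reproduction of the message (the function name is made up for illustration):

```python
import tensorflow as tf


@tf.function
def to_numpy_inside_graph(x):
    # Inside tf.function, x is a symbolic graph tensor, which has no .numpy() method
    return x.numpy()


try:
    to_numpy_inside_graph(tf.constant([1, 2, 3]))
    error_message = None
except AttributeError as err:
    # The same kind of AttributeError as in the thread, raised while tracing the graph
    error_message = str(err)

print(error_message)
```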


It runs fine for me. I created a Colab for you: Google Colab


Thank you for your work. I tested the code and it runs now.

I noticed two things.

First, when I run the code a second time, it is still slow. In the next figure, you can see that I ran the same code again and it took more than a minute.
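To separate one-off compilation cost from steady-state speed, it can help to time several consecutive calls within the same session. With XLA, the first call for a given input shape includes tracing and compilation, and later calls with the same shapes should be much faster. Note that re-running the cell that creates xla_generate builds a new tf.function with an empty cache, so it typically traces and compiles again. A small timing helper (time_calls is a hypothetical name, not part of any library), shown with a stand-in function:

```python
import time
from typing import Any, Callable, List


def time_calls(fn: Callable[..., Any], *args: Any, repeats: int = 3, **kwargs: Any) -> List[float]:
    """Call fn repeatedly with the same arguments; return each call's wall-clock time in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        durations.append(time.perf_counter() - start)
    return durations


# Stand-in usage; with the thread's setup you might call
# time_calls(xla_generate, **tokenized_inputs, **generation_kwargs)
durations = time_calls(sum, [1, 2, 3], repeats=3)
print(durations)
```

If only the first duration is large, the slowdown is compilation; if every call is slow, the shapes or arguments are probably changing between calls.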

Second, I think one of the generation keyword arguments causes this error:

AttributeError: 'Tensor' object has no attribute 'numpy'

I tried removing some arguments from generation_kwargs below, but I can't figure out which one causes the error.

padding_kwargs = {"pad_to_multiple_of": 8, "padding": True}
generation_kwargs = {"num_beams": 7, "max_length": 512,
                     "num_return_sequences": 2, "temperature": 0.7,
                     "do_sample": True, "top_k": 90, "top_p": 0.95,
                     "no_repeat_ngram_size": 2, "early_stopping": True}

tokenized_input = ph_tokenizer([input_prompt], return_tensors="tf", **padding_kwargs)
xla_generate(**tokenized_input, **generation_kwargs)

One last question: will changing the input text require recompiling generate, so that it takes a long time again? Or does the first compilation not depend on the input text itself?
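As far as I understand the blog post, what matters is the shape of the inputs rather than their content: an XLA-compiled generate is reused as long as the input tensors keep the same shape and the generation arguments stay the same. With padding="longest", every new token count is a new shape; pad_to_multiple_of=8 rounds lengths up into buckets, so fewer distinct shapes (and compilations) occur. The arithmetic, with made-up token counts and a hypothetical padded_length helper:

```python
import math


def padded_length(num_tokens: int, multiple: int = 8) -> int:
    """Round a sequence length up to the next multiple (the idea behind pad_to_multiple_of)."""
    return math.ceil(num_tokens / multiple) * multiple


# Made-up token counts for several different input sentences
token_counts = [5, 7, 8, 9, 14, 16, 17]

exact_shapes = set(token_counts)  # padding="longest": one shape per distinct length
bucketed_shapes = {padded_length(n) for n in token_counts}  # padded to multiples of 8

print(sorted(bucketed_shapes))  # [8, 16, 24]
print(len(exact_shapes), len(bucketed_shapes))  # 7 3
```

Here seven sentences would need seven compilations without bucketing but only three with it; a larger multiple trades more padding for even fewer compilations.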

Thank you

I’d recommend reading this part of @joaogante’s blog: Faster Text Generation with TensorFlow and XLA
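A tiny TensorFlow-only experiment (unrelated to Pegasus, and only a sketch) makes this behaviour visible. A Python side effect inside a tf.function runs only while a new graph is traced, so recording it shows when the jit_compile=True function is rebuilt: repeated calls with the same input shape reuse the graph, while a new shape or a new Python argument value (comparable to changing num_beams in generation_kwargs) triggers another trace. The function and list names are made up for illustration.

```python
import tensorflow as tf

trace_log = []


@tf.function(jit_compile=True)
def scaled_sum(x, scale):
    # Python side effect: runs only while TensorFlow traces a new graph
    trace_log.append((tuple(x.shape), scale))
    return tf.reduce_sum(x) * scale


scaled_sum(tf.ones((1, 8)), 2)   # first call: trace and compile
scaled_sum(tf.zeros((1, 8)), 2)  # same shape, same Python value: graph reused
scaled_sum(tf.ones((1, 16)), 2)  # new input shape: traced again
scaled_sum(tf.ones((1, 8)), 3)   # new Python argument value: traced again

print(len(trace_log))  # 3
```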
