Reducing latency for GPT-J

Hello

I’m using GPT-J locally on an NVIDIA RTX 3090 GPU. Currently, I’m using the model in the following way:

import torch
import transformers
from transformers import GPTJForCausalLM

config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-j-6B",
    pad_token='<|endoftext|>',
    eos_token='<|endoftext|>',
    truncation_side='left',
)
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",          # branch with float16 weights
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    use_cache=True,
    gradient_checkpointing=True,
)
model.to('cuda')

prompt = tokenizer(text, return_tensors='pt', truncation=True, max_length=2048)
prompt = {key: value.to('cuda') for key, value in prompt.items()}
out = model.generate(
    **prompt,
    min_length=16,
    max_new_tokens=75,
    do_sample=True,
    top_k=35,
    top_p=0.9,
    temperature=0.75,
    no_repeat_ngram_size=4,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
)
res = tokenizer.decode(out[0], clean_up_tokenization_spaces=True)

As input to the model I’m using 2048 tokens, and I produce 75 tokens as output. The latency is around 4-5 seconds. I’ve read in a blog post that latency can be improved by using pipelines and that tokenization can be a bottleneck.

Can the tokenization be improved for my code and would using a pipeline reduce the latency? Are there any other things I can do to reduce the latency?
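For reference, a pipeline wraps tokenization, generation, and decoding in a single call. A minimal sketch of what that would look like with the same sampling settings (my illustration, reusing the `text` variable from the snippet above; a pipeline is mainly a convenience and is not guaranteed to be faster):

import torch
from transformers import pipeline

# Text-generation pipeline: tokenization and decoding happen inside the call.
generator = pipeline(
    "text-generation",
    model="EleutherAI/gpt-j-6B",
    torch_dtype=torch.float16,
    device=0,  # first CUDA device
)
res = generator(
    text,
    max_new_tokens=75,
    do_sample=True,
    top_k=35,
    top_p=0.9,
    temperature=0.75,
)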


Hi,

Are you managing to run it on a single GPU? If yes, using TensorRT might be your best bet for top-notch acceleration. For this, there are several paths:

You could try out the Optimum library for an easy integration and use of ONNX + ONNX Runtime, but given the size of the model, I can’t give any guarantees as to how well it would perform – it’s also something I want to try. In any case, we may want to have https://github.com/huggingface/optimum/pull/255 and “Merge decoder and decoder with past to reduce memory usage” (https://github.com/huggingface/optimum/issues/530) solved before this is usable.

I don’t have much experience with mixed precision, but you should be able to use it there as well.
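For illustration, mixed precision in plain PyTorch can be tried with autocast; a minimal sketch (my assumption of the setup, with a float32 `model` and the `prompt` dict from the first post):

import torch

# Run generation with eligible ops automatically cast to float16.
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model.generate(**prompt, max_new_tokens=75)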

I have no idea about the questions on the tokenizer / pipeline, but I would guess they are not the bottleneck here. You could profile to check.
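For example, timing the tokenization and generation steps separately would show where the time goes; a quick sketch, reusing `tokenizer`, `model`, and `text` from the first post:

import time
import torch

def timed(fn, *args, **kwargs):
    torch.cuda.synchronize()  # don't let pending GPU work skew the clock
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return result, time.perf_counter() - start

prompt, t_tokenize = timed(tokenizer, text, return_tensors="pt", truncation=True, max_length=2048)
prompt = {key: value.to("cuda") for key, value in prompt.items()}
out, t_generate = timed(model.generate, **prompt, max_new_tokens=75)
print(f"tokenize: {t_tokenize:.3f}s, generate: {t_generate:.3f}s")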

Thank you very much for this great answer.

The second option you mentioned is probably the easiest, as Hugging Face directly supports converting the GPT-J model to ONNX.

Another possibility would be to use DeepSpeed to speed up inference. Do you know whether ONNX would provide more of a speed-up than DeepSpeed?

About DeepSpeed Inference: I think it’s more geared towards parallelism over several GPUs, but maybe there is more to it: Getting Started with DeepSpeed for Inferencing Transformer based Models - DeepSpeed

I think DeepSpeed can also speed up inference on a single GPU.
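For reference, a minimal single-GPU sketch with DeepSpeed’s kernel injection might look like this (my assumption of the setup, reusing `model` and `prompt` from the first post; requires pip install deepspeed):

import torch
import deepspeed

# Wrap the already-loaded GPTJForCausalLM with DeepSpeed's fused inference kernels.
ds_engine = deepspeed.init_inference(
    model,
    mp_size=1,                        # single GPU, no tensor parallelism
    dtype=torch.half,
    replace_with_kernel_inject=True,  # inject optimized transformer kernels
)
model = ds_engine.module
out = model.generate(**prompt, max_new_tokens=75)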

I found a tutorial describing how to export GPT-J to ONNX, but I did not find one describing how to then use it with ONNX Runtime. Do you know of a tutorial, or can you provide an example?

Hi @Eichhof, to use GPT-J with ONNX Runtime, the simplest option is the Optimum library, which has an ONNX Runtime integration: 🤗 Optimum

The workflow would be:

# 1. Export the model (with past key/values) to ONNX:
python -m optimum.exporters.onnx --task causal-lm-with-past --for-ort --model EleutherAI/gpt-j-6B gptj_onnx/

# 2. Load the exported model with ONNX Runtime and generate:
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("gptj_onnx")
model = ORTModelForCausalLM.from_pretrained("gptj_onnx", provider="TensorrtExecutionProvider")

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
gen_tokens = model.generate(**inputs)

print(tokenizer.batch_decode(gen_tokens))

Reference: Accelerated inference on NVIDIA GPUs

Several disclaimers apply, so you may encounter errors at this stage. I think the support for large models will be better in the next release.

I have been looking into converting GPT-J to ONNX and running it on my RTX 3090. I do not know how to convert it to float16 to be able to run it on my card.

Can you please guide me to the right direction?

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("gptj_onnx")
model = ORTModelForCausalLM.from_pretrained("gptj_onnx", provider="TensorrtExecutionProvider")

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
gen_tokens = model.generate(**inputs)

print(tokenizer.batch_decode(gen_tokens))

Is this running in float16, as mentioned by shazde? If not, how is this possible?
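Not an authoritative answer, but one common route is the onnxconverter-common package, which casts an ONNX graph’s float32 tensors to float16. A sketch, assuming the gptj_onnx/ export directory from above (everything else is my assumption):

import onnx
from onnxconverter_common import float16

# Load the exported graph; a 6B model stores its weights as external data files.
model = onnx.load("gptj_onnx/model.onnx")

# Cast float32 initializers and ops to float16, keeping float32 graph inputs/outputs.
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)

onnx.save_model(
    model_fp16,
    "gptj_onnx/model_fp16.onnx",
    save_as_external_data=True,  # required once the graph exceeds the 2 GB protobuf limit
)

Alternatively, I believe the TensorRT execution provider has its own fp16 mode (the trt_fp16_enable provider option), which casts at runtime without touching the exported file.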

Hi @fxmarty

Thank you very much for the code, I appreciate your help. I tried it out.

First, I ran python -m optimum.exporters.onnx --task causal-lm-with-past --for-ort --model EleutherAI/gpt-j-6B gptj_onnx/, but --for-ort is not recognized. Do you know why?

Second, without --for-ort it works, but I’m getting the error shown below. Do you know this error and why it happens?

By the way, I did the installation with pip install optimum[onnxruntime-gpu].

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Framework not specified. Using pt to export to ONNX.
Using framework PyTorch: 1.12.1
Overriding 2 configuration item(s)
        - use_cache -> True
        - pad_token_id -> 0
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:597: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if batch_size <= 0:
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:177: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
(the warning above is repeated 28 times)
Validating ONNX model...
        -[✓] ONNX model output names match reference model (present.22.value, present.15.key, present.15.value, present.25.value, present.9.value, present.26.value, present.8.value, present.13.key, present.27.key, present.6.value, present.7.value, present.12.value, present.24.key, present.1.value, present.4.key, logits, present.10.key, present.9.key, present.16.key, present.0.key, present.19.key, present.21.key, present.4.value, present.23.value, present.3.key, present.17.key, present.6.key, present.21.value, present.22.key, present.18.key, present.11.key, present.10.value, present.14.value, present.0.value, present.13.value, present.14.key, present.5.value, present.2.value, present.16.value, present.24.value, present.25.key, present.27.value, present.8.key, present.7.key, present.19.value, present.20.key, present.26.key, present.18.value, present.23.key, present.11.value, present.2.key, present.5.key, present.3.value, present.1.key, present.20.value, present.17.value, present.12.key)
        - Validating ONNX Model output "logits":
                -[✓] (2, 16, 50400) matches (2, 16, 50400)
                -[x] values not close enough, max diff: 3.2901763916015625e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.7.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.6702880859375e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.9.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.09808349609375e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.9.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.3589859008789062e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.10.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.2636184692382812e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.11.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.574920654296875e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.13.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.17.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.25.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.1682510375976562e-05 (atol: 1e-05)
        (all other present.*.key / present.*.value outputs validate: shapes (2, 16, 32, 256) match and all values close (atol: 1e-05))
An error occured, but the model was saved at: gptj_onnx/model.onnx