Speeding up GPT2 generation

System Setup

  • Pop!_OS 20.04
  • PyTorch: 1.5.1
  • Transformers: 3.0.2
  • Python: 3.7.6

Background Code

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import time
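

def run_func(func, *args, print_str='', **kwargs):
    # Timing helper used throughout but not shown in the original post; this is a
    # reconstruction that matches the log format below. Calls func(*args, **kwargs),
    # prints the elapsed wall-clock time, and returns func's result.
    tic = time.perf_counter()
    result = func(*args, **kwargs)
    toc = time.perf_counter()
    print('{} elapsed time: {} seconds'.format(print_str, toc - tic))
    return result
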
def time_gpt2_gen():
    prompt1 = 'We present an update on the results of the Double Chooz experiment. Double Chooz searches for the neutrino mixing angle, θ13, in the three-neutrino mixing matrix via the disappearance of produced by the dual 4.27 GW/th Chooz B Reactors. Here we discuss updated oscillation fit results using both the rate and the shape of the anti-neutrino energy spectrum. In the most recent oscillation analysis we included data with neutron captures on Gadolinium and Hydrogen along with the reactor off data that we collected. This is an important step in our multi-year program to establish the value of θ13.'
    prompt2 = 'The paper covers detailed discussion on novel control system developed for adaptive fluid-based shock-absorbers serving for mitigation of unknown impact excitations. In order to provide complete independence of the control system from the loading conditions, the Hybrid Prediction Control (HPC) was elaborated. The proposed method is an extension of previously introduced kinematic feedback control which ensures optimal path finding, tracking and path update in case of high disturbance or sudden change of loading conditions. Implementation of the presented control system allows to obtain self-adaptive fluid-based absorbers providing robust impact mitigation. In contrast to previously developed methods of Adaptive Impact Absorption, the proposed control strategy does not require prior knowledge of impact excitation or its preliminary identification. The independence of applied control system from parameters of impact loading results in the capability of automatic path correction in the case of disturbance occurrence and re-adaptation to a number of subsequent impacts. The successful operation of the self-adaptive system is investigated with the use of numerical examples involving double-chamber pneumatic shock-absorber equipped with controllable valve. Efficiency of the HPC is proved by comparison with passive absorber as well as device equipped with adaptive and optimal control modules.'
    prompt3 = 'This study aimed to produce biosurfactant from Pseudozyma tsukubaensis using cassava wastewater and an inoculum (biomass) for galactooligosaccharides synthesis from lactose as an integrated system. First, the use of cassava wastewater as a low cost culture medium by P. tsukubaensis to produce biomass and biosurfactant was evaluated and optimized. Then, the microbial cells (biomass) obtained from the optimized process were used to produce galactooligosaccharides from lactose. The optimum conditions for biosurfactant and biomass synthesis were found to be 80% (v/v) of cassava wastewater at 30°C and 200rpm for 48h. The highest concentration of biosurfactant, that is, minimum surface tension value and maximum biomass concentration predicted were experimentally confirmed as 26.87mN/m and 10.5g/L, respectively. The biosurfactant obtained showed good thermal (121°C/1h), pH (2–11) and ionic strength (0–25% NaCl) stability. Excellent emulsifier activity was also verified, suggesting a potential application in enhanced oil recovery. Galactooligosaccharides synthesized by the Kluyveromyces genus have been extensively investigated, however, few studies have reported transgalactosylation ability by other yeast genera. The transgalactosylation activity of the yeast biomass at optimized conditions from 40% (w/w) lactose resulted in galactooligosaccharides production of 73.12g/L and a yield of 18.28% (w/w) at pH 8.0 and 30°C in 24h. This research showed the technical feasibility of an integrated process: biosurfactant and GOS production from P. tsukubaensis, which takes advantage of the remarkable metabolism of this microorganism. To the best of our knowledge, this is the first study reporting the potential of P. tsukubaensis to produce two economical biotechnological products of increase interest as an integrated process.'
    prompt4 = 'Advantages of a fuzzy predictive control algorithm are discussed in the paper. The fuzzy predictive algorithm is a combination of a DMC (Dynamic Matrix Control) algorithm and Takagi–Sugeno fuzzy modeling, thus it inherits advantages of both techniques. The algorithm is numerically effective. It is in fact generalization of the standard DMC algorithm widely used in the industry, thus the existing implementations of the DMC algorithm can be extended using the presented fuzzy approach. A simple and easy to apply method of fuzzy predictive control algorithms synthesis is presented in the paper. It can be easy applied also in the case of Multiple Input Multiple Output (MIMO) control plants. Moreover, information about measured disturbance can be included in the algorithms in an easy way. The advantages of the fuzzy predictive control algorithm are demonstrated in the example control systems of two nonlinear chemical reactors: the first one—with inverse response and the second one—a MIMO plant with time delay.'
    batch = [prompt1, prompt2, prompt3, prompt4]

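    # Left padding keeps each prompt's last real token in the final column, so the
    # lm_logits_w_temp[:, -1, :] lookup below reads a real token instead of a pad token.
    tokenizer = run_func(GPT2Tokenizer.from_pretrained, 'gpt2', print_str='Initialize GPT2Tokenizer', padding_side='left')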
    tokenizer.pad_token = tokenizer.eos_token
    encoded_results = run_func(tokenizer, batch, print_str='Calculate initial encodings', padding=True, truncation=True, return_tensors='pt', return_attention_mask=True)

    gpt2 = run_func(GPT2LMHeadModel.from_pretrained, 'gpt2', print_str='Initialize GPT2LMHeadModel')

    temperature = 0.92
    tmp_input_ids = encoded_results['input_ids']
    tmp_attention_mask = encoded_results['attention_mask']
    max_gen_length = 30
    counter = 0
    gen_dict = {'a1': '', 'a2': '', 'a3': '', 'a4': ''}
    while_tic = time.perf_counter()
    while counter < max_gen_length:
        print('\ncounter = {}'.format(counter))

        outputs = run_func(gpt2, print_str='   Calculate GPT2 outputs', input_ids=tmp_input_ids,
                           attention_mask=tmp_attention_mask
                           )

        # (batch_size, sequence_length, vocab_size)
        lm_logits_w_temp = outputs[0] / temperature

        # (batch_size, vocab_size)
        last_tokens = lm_logits_w_temp[:, -1, :]
        last_token_softmaxes = run_func(torch.softmax, last_tokens, print_str='   Last token softmax', dim=-1).squeeze()

        next_tokens = run_func(torch.multinomial, last_token_softmaxes, print_str='   Generate next token', num_samples=1)

        list_comp_tic = time.perf_counter()
        next_strs = [tokenizer.decode(next_token).strip() for next_token in next_tokens]
        prev_input_strs = [tokenizer.decode(id_tensor, skip_special_tokens=True) for id_tensor in tmp_input_ids]
        prev_split_list = [prev_input_str.split() for prev_input_str in prev_input_strs]
        list_comp_toc = time.perf_counter()
        print('   List comprehension calcs elapsed time: {} seconds'.format(list_comp_toc-list_comp_tic))

        gen_dict['a1'] += next_strs[0] + ' '
        gen_dict['a2'] += next_strs[1] + ' '
        gen_dict['a3'] += next_strs[2] + ' '
        gen_dict['a4'] += next_strs[3] + ' '
        str_list_to_join = []

        next_strs_tic = time.perf_counter()
        for ii, prev_split2 in enumerate(prev_split_list):
            next_str = next_strs[ii]
            tmp_prev = prev_split2
            tmp_prev.append(next_str)
            str_list_to_join.append(tmp_prev)
        next_inputs = [' '.join(str_to_join) for str_to_join in str_list_to_join]
        next_strs_toc = time.perf_counter()
        print('   Add generated tokens onto previous full strings elapsed time: {} seconds'.format(next_strs_toc-next_strs_tic))

        if counter == max_gen_length - 1:
            final_str_batch = next_inputs
        else:
            new_encoded_results = run_func(tokenizer, next_inputs, print_str='   Tokenizing next full strings', padding=True, truncation=True, return_tensors='pt',
                                            return_attention_mask=True)
            tmp_input_ids = new_encoded_results['input_ids']
            tmp_attention_mask = new_encoded_results['attention_mask']

        counter += 1

    while_toc = time.perf_counter()
    print('Time to complete while loop for {} passes: {} seconds'.format(max_gen_length, while_toc-while_tic))
    print('------------------------------------------------------------------------------------')

run_func(time_gpt2_gen, print_str='Total time')

Question
How can I speed up GPT2 generation? I feel like this is a pretty naive implementation, so I’d love to hear feedback on speed-optimization strategies. Here is the output for generating 30 new tokens on a batch of 4 scientific abstracts. The whole run takes about 4.8 minutes (289.6 seconds), and nearly all of that is the GPT2 forward pass, at roughly 7 to 11 seconds per step.

Initialize GPT2Tokenizer elapsed time: 1.4727641880017472 seconds
Calculate initial encodings elapsed time: 0.03799170200363733 seconds
Initialize GPT2LMHeadModel elapsed time: 18.23160586800077 seconds

counter = 0
   Calculate GPT2 outputs elapsed time: 9.893233462003991 seconds
   Last token softmax elapsed time: 0.0007064949968480505 seconds
   Generate next token elapsed time: 0.0007856059964979067 seconds
   List comprehension calcs elapsed time: 0.13995413600059692 seconds
   Add generated tokens onto previous full strings elapsed time: 3.628500417107716e-05 seconds
   Tokenizing next full strings elapsed time: 0.012479234996135347 seconds

counter = 1
   Calculate GPT2 outputs elapsed time: 8.576971583999693 seconds
   Last token softmax elapsed time: 0.0006583049980690703 seconds
   Generate next token elapsed time: 0.0007301439982256852 seconds
   List comprehension calcs elapsed time: 0.12174680699536111 seconds
   Add generated tokens onto previous full strings elapsed time: 4.426100349519402e-05 seconds
   Tokenizing next full strings elapsed time: 0.011608965003688354 seconds

counter = 2
   Calculate GPT2 outputs elapsed time: 8.446395003004 seconds
   Last token softmax elapsed time: 0.0006648650014540181 seconds
   Generate next token elapsed time: 0.0007324170001083985 seconds
   List comprehension calcs elapsed time: 0.14154231199790956 seconds
   Add generated tokens onto previous full strings elapsed time: 5.113699444336817e-05 seconds
   Tokenizing next full strings elapsed time: 0.014487013999314513 seconds

counter = 3
   Calculate GPT2 outputs elapsed time: 9.05919142899802 seconds
   Last token softmax elapsed time: 0.0007119319998309948 seconds
   Generate next token elapsed time: 0.0007321929952013306 seconds
   List comprehension calcs elapsed time: 0.1303824819988222 seconds
   Add generated tokens onto previous full strings elapsed time: 4.862200148636475e-05 seconds
   Tokenizing next full strings elapsed time: 0.012177805001556408 seconds

counter = 4
   Calculate GPT2 outputs elapsed time: 8.18601783199847 seconds
   Last token softmax elapsed time: 0.02430849100346677 seconds
   Generate next token elapsed time: 0.0007163990012486465 seconds
   List comprehension calcs elapsed time: 0.15182566899602534 seconds
   Add generated tokens onto previous full strings elapsed time: 4.859599721385166e-05 seconds
   Tokenizing next full strings elapsed time: 0.014472196999122389 seconds

counter = 5
   Calculate GPT2 outputs elapsed time: 8.919605226001295 seconds
   Last token softmax elapsed time: 0.0006643329979851842 seconds
   Generate next token elapsed time: 0.0007834899952285923 seconds
   List comprehension calcs elapsed time: 0.12023409699759213 seconds
   Add generated tokens onto previous full strings elapsed time: 5.939300172030926e-05 seconds
   Tokenizing next full strings elapsed time: 0.011963372999161948 seconds

counter = 6
   Calculate GPT2 outputs elapsed time: 7.393029450999165 seconds
   Last token softmax elapsed time: 0.0007156190040404908 seconds
   Generate next token elapsed time: 0.0007048089973977767 seconds
   List comprehension calcs elapsed time: 0.13225654000416398 seconds
   Add generated tokens onto previous full strings elapsed time: 4.454200097825378e-05 seconds
   Tokenizing next full strings elapsed time: 0.012416008998116013 seconds

counter = 7
   Calculate GPT2 outputs elapsed time: 7.838971406003111 seconds
   Last token softmax elapsed time: 0.0023600620042998344 seconds
   Generate next token elapsed time: 0.0012576709996210411 seconds
   List comprehension calcs elapsed time: 0.1481078790020547 seconds
   Add generated tokens onto previous full strings elapsed time: 5.014600174035877e-05 seconds
   Tokenizing next full strings elapsed time: 0.017490088001068216 seconds

counter = 8
   Calculate GPT2 outputs elapsed time: 9.155262512998888 seconds
   Last token softmax elapsed time: 0.0010118979989783838 seconds
   Generate next token elapsed time: 0.000767368997912854 seconds
   List comprehension calcs elapsed time: 0.1483391650035628 seconds
   Add generated tokens onto previous full strings elapsed time: 9.162300557363778e-05 seconds
   Tokenizing next full strings elapsed time: 0.012755158000800293 seconds

counter = 9
   Calculate GPT2 outputs elapsed time: 10.77434083299886 seconds
   Last token softmax elapsed time: 0.010747615997388493 seconds
   Generate next token elapsed time: 0.0007793960030539893 seconds
   List comprehension calcs elapsed time: 0.14875024900538847 seconds
   Add generated tokens onto previous full strings elapsed time: 5.4737000027671456e-05 seconds
   Tokenizing next full strings elapsed time: 0.01669158499862533 seconds

counter = 10
   Calculate GPT2 outputs elapsed time: 10.110196126996016 seconds
   Last token softmax elapsed time: 0.02143118399544619 seconds
   Generate next token elapsed time: 0.0010516420006752014 seconds
   List comprehension calcs elapsed time: 0.1679224549952778 seconds
   Add generated tokens onto previous full strings elapsed time: 8.338599582202733e-05 seconds
   Tokenizing next full strings elapsed time: 0.014697657999931835 seconds

counter = 11
   Calculate GPT2 outputs elapsed time: 9.811458320000384 seconds
   Last token softmax elapsed time: 0.000684321996232029 seconds
   Generate next token elapsed time: 0.0007668550024391152 seconds
   List comprehension calcs elapsed time: 0.1525734469978488 seconds
   Add generated tokens onto previous full strings elapsed time: 8.502999844495207e-05 seconds
   Tokenizing next full strings elapsed time: 0.015574641001876444 seconds

counter = 12
   Calculate GPT2 outputs elapsed time: 10.353367308998713 seconds
   Last token softmax elapsed time: 0.0010184349957853556 seconds
   Generate next token elapsed time: 0.0007333970061154105 seconds
   List comprehension calcs elapsed time: 0.13797010699636303 seconds
   Add generated tokens onto previous full strings elapsed time: 5.920600233366713e-05 seconds
   Tokenizing next full strings elapsed time: 0.01412803399580298 seconds

counter = 13
   Calculate GPT2 outputs elapsed time: 10.826486637000926 seconds
   Last token softmax elapsed time: 0.0030568980000680313 seconds
   Generate next token elapsed time: 0.0016750930008129217 seconds
   List comprehension calcs elapsed time: 0.13814785299473442 seconds
   Add generated tokens onto previous full strings elapsed time: 5.037700611865148e-05 seconds
   Tokenizing next full strings elapsed time: 0.013047558997641318 seconds

counter = 14
   Calculate GPT2 outputs elapsed time: 7.55167671199888 seconds
   Last token softmax elapsed time: 0.0005718500033253804 seconds
   Generate next token elapsed time: 0.0007639390023541637 seconds
   List comprehension calcs elapsed time: 0.13264076500490773 seconds
   Add generated tokens onto previous full strings elapsed time: 5.1554998208303005e-05 seconds
   Tokenizing next full strings elapsed time: 0.013588694004283752 seconds

counter = 15
   Calculate GPT2 outputs elapsed time: 7.818915773998015 seconds
   Last token softmax elapsed time: 0.0005817519995616749 seconds
   Generate next token elapsed time: 0.0009528160007903352 seconds
   List comprehension calcs elapsed time: 0.13134208500559907 seconds
   Add generated tokens onto previous full strings elapsed time: 5.866299761692062e-05 seconds
   Tokenizing next full strings elapsed time: 0.01613930299936328 seconds

counter = 16
   Calculate GPT2 outputs elapsed time: 8.578778709998005 seconds
   Last token softmax elapsed time: 0.006186169004649855 seconds
   Generate next token elapsed time: 0.0007223930006148294 seconds
   List comprehension calcs elapsed time: 0.148166503997345 seconds
   Add generated tokens onto previous full strings elapsed time: 5.607099592452869e-05 seconds
   Tokenizing next full strings elapsed time: 0.015603724998072721 seconds

counter = 17
   Calculate GPT2 outputs elapsed time: 8.770265252001991 seconds
   Last token softmax elapsed time: 0.0005352449952624738 seconds
   Generate next token elapsed time: 0.0006981940023251809 seconds
   List comprehension calcs elapsed time: 0.13802178599871695 seconds
   Add generated tokens onto previous full strings elapsed time: 4.810000245925039e-05 seconds
   Tokenizing next full strings elapsed time: 0.013090499996906146 seconds

counter = 18
   Calculate GPT2 outputs elapsed time: 9.192157427001803 seconds
   Last token softmax elapsed time: 0.0013120759977027774 seconds
   Generate next token elapsed time: 0.0006921610038261861 seconds
   List comprehension calcs elapsed time: 0.1327957250032341 seconds
   Add generated tokens onto previous full strings elapsed time: 4.9251000746153295e-05 seconds
   Tokenizing next full strings elapsed time: 0.013492570993548725 seconds

counter = 19
   Calculate GPT2 outputs elapsed time: 8.60694089000026 seconds
   Last token softmax elapsed time: 0.003325685000163503 seconds
   Generate next token elapsed time: 0.000762072995712515 seconds
   List comprehension calcs elapsed time: 0.1455982879997464 seconds
   Add generated tokens onto previous full strings elapsed time: 5.135399987921119e-05 seconds
   Tokenizing next full strings elapsed time: 0.01226243400014937 seconds

counter = 20
   Calculate GPT2 outputs elapsed time: 7.204778202001762 seconds
   Last token softmax elapsed time: 0.000613894997513853 seconds
   Generate next token elapsed time: 0.0007320740041905083 seconds
   List comprehension calcs elapsed time: 0.13817419400584185 seconds
   Add generated tokens onto previous full strings elapsed time: 5.026099825045094e-05 seconds
   Tokenizing next full strings elapsed time: 0.011972155996772926 seconds

counter = 21
   Calculate GPT2 outputs elapsed time: 9.892961628000194 seconds
   Last token softmax elapsed time: 0.0006187749968376011 seconds
   Generate next token elapsed time: 0.0007513390009989962 seconds
   List comprehension calcs elapsed time: 0.12703227200108813 seconds
   Add generated tokens onto previous full strings elapsed time: 5.13460036017932e-05 seconds
   Tokenizing next full strings elapsed time: 0.012355157996353228 seconds

counter = 22
   Calculate GPT2 outputs elapsed time: 8.470063239998126 seconds
   Last token softmax elapsed time: 0.0005823390019941144 seconds
   Generate next token elapsed time: 0.0008517509995726869 seconds
   List comprehension calcs elapsed time: 0.15461847899860004 seconds
   Add generated tokens onto previous full strings elapsed time: 5.156599945621565e-05 seconds
   Tokenizing next full strings elapsed time: 0.013039129000389948 seconds

counter = 23
   Calculate GPT2 outputs elapsed time: 7.189594238996506 seconds
   Last token softmax elapsed time: 0.0006244809992494993 seconds
   Generate next token elapsed time: 0.0008493330024066381 seconds
   List comprehension calcs elapsed time: 0.13466514900210313 seconds
   Add generated tokens onto previous full strings elapsed time: 5.236799916019663e-05 seconds
   Tokenizing next full strings elapsed time: 0.012745897998684086 seconds

counter = 24
   Calculate GPT2 outputs elapsed time: 8.126932711005793 seconds
   Last token softmax elapsed time: 0.0006478180002886802 seconds
   Generate next token elapsed time: 0.0011431679959059693 seconds
   List comprehension calcs elapsed time: 0.14799942199897487 seconds
   Add generated tokens onto previous full strings elapsed time: 6.185499660205096e-05 seconds
   Tokenizing next full strings elapsed time: 0.01267693200497888 seconds

counter = 25
   Calculate GPT2 outputs elapsed time: 7.3953299740023795 seconds
   Last token softmax elapsed time: 0.0005992939986754209 seconds
   Generate next token elapsed time: 0.0007862490019761026 seconds
   List comprehension calcs elapsed time: 0.14582869299920276 seconds
   Add generated tokens onto previous full strings elapsed time: 5.401299858931452e-05 seconds
   Tokenizing next full strings elapsed time: 0.013731771003222093 seconds

counter = 26
   Calculate GPT2 outputs elapsed time: 8.018481942002836 seconds
   Last token softmax elapsed time: 0.04562465199705912 seconds
   Generate next token elapsed time: 0.0008113009971566498 seconds
   List comprehension calcs elapsed time: 0.13469937299669255 seconds
   Add generated tokens onto previous full strings elapsed time: 5.5969001550693065e-05 seconds
   Tokenizing next full strings elapsed time: 0.013372118999541271 seconds

counter = 27
   Calculate GPT2 outputs elapsed time: 8.428962316000252 seconds
   Last token softmax elapsed time: 0.0005464290006784722 seconds
   Generate next token elapsed time: 0.0007245989982038736 seconds
   List comprehension calcs elapsed time: 0.13549309900554363 seconds
   Add generated tokens onto previous full strings elapsed time: 4.968699795426801e-05 seconds
   Tokenizing next full strings elapsed time: 0.013636056995892432 seconds

counter = 28
   Calculate GPT2 outputs elapsed time: 7.557854796999891 seconds
   Last token softmax elapsed time: 0.0006319280000752769 seconds
   Generate next token elapsed time: 0.0007067559999995865 seconds
   List comprehension calcs elapsed time: 0.14673257899994496 seconds
   Add generated tokens onto previous full strings elapsed time: 5.2330004109535366e-05 seconds
   Tokenizing next full strings elapsed time: 0.012924063004902564 seconds

counter = 29
   Calculate GPT2 outputs elapsed time: 7.523497490998125 seconds
   Last token softmax elapsed time: 0.0005815810000058264 seconds
   Generate next token elapsed time: 0.0008171070003299974 seconds
   List comprehension calcs elapsed time: 0.14285911399929319 seconds
   Add generated tokens onto previous full strings elapsed time: 5.161000444786623e-05 seconds

Time to complete while loop for 30 passes: 269.7878909170031 seconds
------------------------------------------------------------------------------------
Total time elapsed time: 289.6417858999994 seconds

Thanks in advance for your help!


Decoding is slow for auto-regressive decoders because the tokens are generated one at a time. I can think of only two things to improve generation speed:

  1. Put the model in fp16 (not sure whether GPT-2 works with fp16; I haven’t tried it myself).
  2. See if you can use ONNX to speed up inference. Maybe this will help.
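A minimal sketch of how the first suggestion combines with two inference measures the loop in the question does not use: running under torch.no_grad(), and feeding GPT-2's cached keys and values (past_key_values) back into the model so that each step processes only the newly sampled token, instead of decoding, re-tokenizing, and re-running the whole prompt. It assumes a recent transformers 4.x release (older releases such as 3.0.2 call the cache argument past and return plain tuples) and left-padded inputs. sample_with_cache is a hypothetical name, and fp16 appears only as a commented GPU option because half precision is mainly useful on CUDA.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel


@torch.no_grad()  # inference only: skip building the autograd graph
def sample_with_cache(model, input_ids, attention_mask, max_new_tokens=30, temperature=0.92):
    """Temperature sampling that reuses GPT-2's key/value cache.

    Assumes LEFT-padded inputs, so column -1 holds a real token for every row.
    Returns only the newly sampled token ids, shape (batch, max_new_tokens).
    """
    # Position ids that ignore the left padding (pads get a dummy position).
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    past = None
    step_ids = input_ids
    new_tokens = []
    for _ in range(max_new_tokens):
        out = model(input_ids=step_ids, attention_mask=attention_mask,
                    position_ids=position_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values  # cached keys/values for every token seen so far
        probs = torch.softmax(out.logits[:, -1, :] / temperature, dim=-1)
        next_ids = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        new_tokens.append(next_ids)

        # The next step feeds only the new token; the cache covers the prefix.
        step_ids = next_ids
        ones = attention_mask.new_ones((attention_mask.size(0), 1))
        attention_mask = torch.cat([attention_mask, ones], dim=-1)
        position_ids = position_ids[:, -1:] + 1
    return torch.cat(new_tokens, dim=-1)


if __name__ == "__main__":
    # Offline demo with a tiny, randomly initialised model (no download needed).
    # With the real checkpoint one would instead use
    # GPT2LMHeadModel.from_pretrained("gpt2") and a tokenizer with padding_side="left".
    # Optional, CUDA only: model.half().to("cuda"), and move the inputs to "cuda" too.
    torch.manual_seed(0)
    config = GPT2Config(vocab_size=50, n_positions=64, n_embd=32, n_layer=2, n_head=2,
                        bos_token_id=49, eos_token_id=49)
    model = GPT2LMHeadModel(config).eval()
    ids = torch.randint(0, 49, (2, 5))
    mask = torch.ones_like(ids)
    mask[1, :2] = 0  # second prompt is left-padded by two positions
    print(sample_with_cache(model, ids, mask, max_new_tokens=4))
```

Because the cache grows by one position per step, each forward call works on a single new position rather than on the full, ever-longer sequence. The demo uses a tiny random model only so that it runs offline; its samples are meaningless.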

@valhalla already gave a great answer if you want to speed up generation with the same parameters.

But you can also change beam-search parameters to speed up the generation, at the cost of text quality, which may be lower. For example, you can:

  • Reduce the beam size
  • Reduce the length of the text to be generated
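For the parameter changes, a sketch that delegates to the library's own generate() method, which already reuses the key/value cache internally. The helper name generate_cheaply and its defaults are illustrative assumptions; temperature 0.92 and 30 new tokens mirror the question. max_length is computed from the prompt length because, for GPT-2, it counts the prompt tokens as well.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel


def generate_cheaply(model, input_ids, attention_mask, new_tokens=30, num_beams=1, temperature=0.92):
    """Sketch: call the built-in generate() with speed-oriented settings.

    num_beams=1 keeps a single hypothesis per prompt (beam search with k beams
    runs roughly k sequences per prompt), and new_tokens bounds the number of
    sequential decoding steps. Returns the prompt ids followed by the new ids.
    """
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=input_ids.shape[1] + new_tokens,  # max_length includes the prompt
        num_beams=num_beams,
        do_sample=True,
        temperature=temperature,
        pad_token_id=model.config.eos_token_id,  # GPT-2 has no dedicated pad token
    )


if __name__ == "__main__":
    # Tiny random model so the demo runs offline; use from_pretrained("gpt2") for real text.
    torch.manual_seed(0)
    config = GPT2Config(vocab_size=50, n_positions=64, n_embd=32, n_layer=2, n_head=2,
                        bos_token_id=49, eos_token_id=49)
    model = GPT2LMHeadModel(config).eval()
    ids = torch.randint(0, 49, (2, 5))
    print(generate_cheaply(model, ids, torch.ones_like(ids), new_tokens=6))
```

Raising num_beams makes generate() track that many hypotheses per prompt, which multiplies the work per step; lowering new_tokens cuts the number of sequential forward passes.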

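Model compression such as distillation can also increase speed, e.g. swapping gpt2 for the distilled distilgpt2 checkpoint, which has 6 transformer layers instead of 12.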
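A rough way to see where a distilled checkpoint saves time is to count parameters. The sketch below assumes the published shapes: GPT-2 small has 12 transformer blocks of width 768, a 50257-token vocabulary and 1024 positions, while distilgpt2 keeps the same width and vocabulary but has 6 blocks. The count treats the output layer as tied to the token embedding, as GPT-2 does; gpt2_param_count is a hypothetical helper.

```python
def gpt2_param_count(n_layer, d_model=768, vocab_size=50257, n_positions=1024):
    """Parameter count of a GPT-2-style LM with tied input/output embeddings."""
    embeddings = vocab_size * d_model + n_positions * d_model  # token + position tables
    per_block = (
        2 * d_model                               # ln_1 weight + bias
        + d_model * 3 * d_model + 3 * d_model     # fused query/key/value projection
        + d_model * d_model + d_model             # attention output projection
        + 2 * d_model                             # ln_2 weight + bias
        + d_model * 4 * d_model + 4 * d_model     # MLP up-projection
        + 4 * d_model * d_model + d_model         # MLP down-projection
    )
    final_ln = 2 * d_model                        # ln_f weight + bias
    return embeddings + n_layer * per_block + final_ln


gpt2 = gpt2_param_count(n_layer=12)
distil = gpt2_param_count(n_layer=6)
print(gpt2, distil)             # 124439808 81912576
print(round(distil / gpt2, 3))  # 0.658
```

The total only falls to about 66 percent because both models share the same 38.6M-parameter embedding matrix, but the work done inside the transformer blocks for every generated token is halved.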