How to apply pruning on a BERT model?

I have trained a BERT model using ktrain (a wrapper around the Hugging Face transformers library) to recognize emotion in text. It works, but inference is really slow, which makes the model unsuitable for a production environment. I have done some research and it seems pruning could help.

The problem is that pruning is not a widely used technique and I cannot find a simple enough example on Kaggle or Stack Overflow that would help me understand how to use it. Can someone help?

I provide my working code below for reference. My question can also be found on Stack Overflow: https://stackoverflow.com/questions/64445784/how-to-apply-pruning-on-a-bert-model

import pandas as pd
import numpy as np
import preprocessor as p
import emoji
import re
import ktrain
from ktrain import text
from unidecode import unidecode
import nltk

#text preprocessing class
class TextPreprocessing:
    def __init__(self):
        p.set_options(p.OPT.MENTION, p.OPT.URL)
  
    def _punctuation(self,val): 
        val = re.sub(r'[^\w\s]',' ',val)
        val = re.sub('_', ' ',val)
        return val
  
    def _whitespace(self,val):
        return " ".join(val.split())
  
    def _removenumbers(self,val):
        val = re.sub('[0-9]+', '', val)
        return val
  
    def _remove_unicode(self, text):
        text = unidecode(text).encode("ascii")
        text = str(text, "ascii")
        return text  
    
    def _split_to_sentences(self, body_text):
        sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", body_text)
        return sentences
    
    def _clean_text(self,val):
        val = val.lower()
        val = self._removenumbers(val)
        val = p.clean(val)
        val = ' '.join(self._punctuation(emoji.demojize(val)).split())
        val = self._remove_unicode(val)
        val = self._whitespace(val)
        return val
  
    def text_preprocessor(self, body_text):

        body_text_df = pd.DataFrame({"body_text": body_text},index=[1])

        sentence_split_df = body_text_df.copy()

        sentence_split_df["body_text"] = sentence_split_df["body_text"].apply(
            self._split_to_sentences)

        lst_col = "body_text"
        sentence_split_df = pd.DataFrame(
            {
                col: np.repeat(
                    sentence_split_df[col].values,
                    sentence_split_df[lst_col].str.len(),
                )
                for col in sentence_split_df.columns.drop(lst_col)
            }
        ).assign(**{lst_col: np.concatenate(sentence_split_df[lst_col].values)})[
            sentence_split_df.columns
        ]
        
        body_text_df["body_text"] = body_text_df["body_text"].apply(self._clean_text)

        final_df = (
            pd.concat([sentence_split_df, body_text_df])
            .reset_index()
            .drop(columns=["index"])
        )
        
        return final_df["body_text"]

#instantiate data preprocessing object
text1 = TextPreprocessing()

#import data
data_train = pd.read_csv('data_train_v5.csv', encoding='utf8', engine='python')
data_test = pd.read_csv('data_test_v5.csv', encoding='utf8', engine='python')

#clean the data
data_train['Text'] = data_train['Text'].apply(text1._clean_text)
data_test['Text'] = data_test['Text'].apply(text1._clean_text)

X_train = data_train.Text.tolist()
X_test = data_test.Text.tolist()

y_train = data_train.Emotion.tolist()
y_test = data_test.Emotion.tolist()

data = data_train.append(data_test, ignore_index=True)

class_names = ['joy','sadness','fear','anger','neutral']

encoding = {
    'joy': 0,
    'sadness': 1,
    'fear': 2,
    'anger': 3,
    'neutral': 4
}

# Integer values for each class
y_train = [encoding[x] for x in y_train]
y_test = [encoding[x] for x in y_test]

trn, val, preproc = text.texts_from_array(x_train=X_train, y_train=y_train,
                                          x_test=X_test, y_test=y_test,
                                          class_names=class_names,
                                          preprocess_mode='distilbert',
                                          maxlen=350)

model = text.text_classifier('distilbert', train_data=trn, preproc=preproc)

learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)

predictor = ktrain.get_predictor(learner.model, preproc)

#save the model on a file for later use
predictor.save("models/bert_model")

message = "This is a happy message"

#cleaning - takes 5ms to run
clean = text1._clean_text(message)

#prediction - takes 325 ms to run
predictor.predict_proba(clean)

I don’t know how to perform pruning. The idea is simple enough (cut out attention heads that don’t appear to be doing anything useful), but the implementation will be tricky.

Have you considered fine-tuning a DistilBERT or ALBERT model instead?

@rgwatwormhill I have tried fine-tuning DistilBERT but this did not solve the inference speed problem. As per this article, pruning is probably the only way to increase inference speed, which is much needed in a production environment.

@stamatis,
thanks for the link, that’s interesting, I didn’t know people were pruning weights.

Have you seen this article about pruning heads?

This documentation page includes a command to prune particular heads of a model.
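For what it's worth, here is a minimal sketch using the prune_heads() method from transformers (which I assume is what that page refers to); it works on the PyTorch model classes, and the layer/head indices below are only placeholders:

from transformers import DistilBertForSequenceClassification

# load the PyTorch version of the model (path/name and num_labels are placeholders)
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5
)

# keys are layer indices, values are the head indices to remove in that layer;
# which heads can safely go is model- and task-specific, these numbers are only examples
heads_to_prune = {
    0: [2, 4],   # remove heads 2 and 4 in layer 0
    5: [1],      # remove head 1 in layer 5
}
model.prune_heads(heads_to_prune)

# pruned heads are recorded in the config, so they stay removed after save/reload
print(model.config.pruned_heads)
model.save_pretrained("models/distilbert_pruned")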

The BertViz visualisation tool might help you decide which heads to delete, see

As a quick guess, heads where the attention for a word all goes to the same word or to the subsequent word, and heads where all the attention from every word goes to the SEP token, may not be very useful.
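For example, a rough sketch of rendering those attention patterns with BertViz in a notebook (the model name here is a placeholder, you would point it at your own fine-tuned weights):

from transformers import AutoTokenizer, AutoModel
from bertviz import head_view

model_name = "distilbert-base-uncased"        # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

sentence = "This is a happy message"
input_ids = tokenizer.encode(sentence, return_tensors="pt")
outputs = model(input_ids)

attention = outputs[-1]                        # one attention tensor per layer
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
head_view(attention, tokens)                   # renders the interactive view in a notebook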

I guess you could also use the command to do some systematic pruning and test the results.
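Something like this, maybe (PyTorch model again; which heads to drop and the latency-only check are just illustrative, you would also want to re-check accuracy on your validation set after each pruning run):

import time
import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

model_name = "distilbert-base-uncased"         # placeholder for your fine-tuned weights
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("This is a happy message", return_tensors="pt")

def avg_latency(model, n_runs=20):
    model.eval()
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n_runs):
            model(**inputs)
    return (time.perf_counter() - start) / n_runs

# prune half the heads in one layer at a time and see what it saves
for layer in range(6):                         # DistilBERT base has 6 layers of 12 heads
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=5)
    model.prune_heads({layer: list(range(6))})
    print(f"layer {layer} pruned: {avg_latency(model) * 1000:.1f} ms per forward pass")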

If you are looking to speed up inference, you could give ONNX a try. transformers has a script which lets you export models to ONNX.
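For example, a rough sketch with the convert_graph_to_onnx helper plus onnxruntime (paths, the opset, and input dtypes are assumptions and may need adjusting for your model):

from pathlib import Path
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from transformers.convert_graph_to_onnx import convert

model_name = "distilbert-base-uncased"         # placeholder, point this at your saved model
onnx_path = Path("onnx/distilbert.onnx")       # the export folder should be new or empty

# export the graph to ONNX (framework="tf" since the ktrain model is TensorFlow-based)
convert(framework="tf", model=model_name, output=onnx_path, opset=11)

# run the exported graph with onnxruntime
tokenizer = AutoTokenizer.from_pretrained(model_name)
session = ort.InferenceSession(str(onnx_path))

encoded = tokenizer("This is a happy message", return_tensors="np")
# depending on the exported graph, the inputs may need to be int32 instead of int64
ort_inputs = {k: encoded[k].astype(np.int64) for k in ("input_ids", "attention_mask")}
outputs = session.run(None, ort_inputs)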

These two blog posts will help

Also, you can try onnx_transformers, which lets you use the same pipeline API but leverages ONNX to speed up inference.
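Roughly like this, going by the package's README (the task and default model here are just an example):

from onnx_transformers import pipeline

# onnx=True exports the underlying model to ONNX on first use and runs it with onnxruntime
classifier = pipeline("sentiment-analysis", onnx=True)
print(classifier("This is a happy message"))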
