How can I convert text to word embeddings with BERT's pretrained model faster?

I’m trying to get word embeddings for clinical data using microsoft/pubmedbert. I have 3.6 million text rows. Converting 10k rows to vectors takes around 30 minutes, so 3.6 million rows would take around 180 hours (about 7.5 days).

Is there any way to speed up the process?

My code -

import re
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, pipeline

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
classifier = pipeline('feature-extraction', model=model_name, tokenizer=tokenizer)

def lambda_func(row):
    tokens = tokenizer(row['notetext'])
    if len(tokens['input_ids']) > 512:
        # crude truncation: split on word boundaries and keep the first 512 pieces,
        # which only approximates the model's 512-wordpiece limit
        tokens = re.split(r'\b', row['notetext'])
        tokens = [t for t in tokens if len(t) > 0]
        row['notetext'] = ''.join(tokens[:512])
    # [0][0] is the embedding of the first token ([CLS]) from the last hidden state
    row['vectors'] = classifier(row['notetext'])[0][0]
    return row

def process(progress_notes):
    # row-by-row apply: one forward pass per note
    progress_notes = progress_notes.apply(lambda_func, axis=1)
    return progress_notes

progress_notes = process(progress_notes)
vectors_breadth = 768
vectors_length = len(progress_notes)
vectors_2d = np.reshape(progress_notes['vectors'].to_list(), (vectors_length, vectors_breadth))
vectors_df = pd.DataFrame(vectors_2d)
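
For reference, the manual word-boundary truncation and the row-by-row pipeline calls could both be replaced by the tokenizer's built-in truncation plus batched forward passes. A minimal sketch, assuming PyTorch and that the 768-d vector I want is the [CLS] token of the last hidden state (what classifier(text)[0][0] returns); the batch size is illustrative, not tuned:

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

def embed_batched(texts, batch_size=64):
    vectors = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            # built-in truncation replaces the manual re.split() workaround
            enc = tokenizer(batch, padding=True, truncation=True,
                            max_length=512, return_tensors="pt")
            out = model(**enc)
            # first token of the last hidden state = [CLS] embedding
            vectors.append(out.last_hidden_state[:, 0, :])
    return torch.cat(vectors).numpy()

vectors_2d = embed_batched(progress_notes['notetext'].tolist())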

My progress_notes dataframe looks like -

progress_notes = pd.DataFrame({
    'id': [1, 2, 3],
    'progressnotetype': ['Nursing Note', 'Nursing Note', 'Administration Note'],
    'notetext': [
        "Patient's skin is grossly intact with exception of skin tear to r inner elbow and r lateral lower leg",
        'Patient with history of Afib with RVR. Patient is incontinent of bowel and bladder.',
        'Give 2 tablet by mouth every 4 hours as needed for Mild to moderate Pain Not to exceed 3 grams in 24 hours'
    ]
})

Note - I’m running the code on an AWS EC2 r5.8xlarge instance (32 vCPUs). I tried multiprocessing, but the code deadlocks because BERT takes all my CPU cores.
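
On the deadlock: one common cause is thread oversubscription, where every worker process spawns its own full set of OpenMP/PyTorch threads on the same 32 cores. A minimal sketch of capping threads per worker, assuming a PyTorch backend; the number of processes and the chunking are illustrative, not tuned values:

import os
os.environ["OMP_NUM_THREADS"] = "1"          # set before torch/transformers are imported

from multiprocessing import Pool
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

def embed_chunk(texts):
    torch.set_num_threads(1)                 # one intra-op thread per worker process
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()
    out = []
    with torch.no_grad():
        for text in texts:
            enc = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
            out.append(model(**enc).last_hidden_state[0, 0].numpy())
    return np.stack(out)

if __name__ == "__main__":
    chunks = np.array_split(progress_notes['notetext'].tolist(), 8)
    with Pool(processes=8) as pool:
        vectors_2d = np.concatenate(pool.map(embed_chunk, [list(c) for c in chunks]))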

hi @madhuryadav,

You could try ONNX Runtime here to get some speed-up. Here’s a notebook which shows how to use ONNX Runtime with BERT.
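
Roughly, with the optimum library's ONNX Runtime integration it could look like this (a minimal sketch, not from the notebook; ORTModelForFeatureExtraction with export=True assumes a recent optimum release, older ones use from_transformers=True instead):

from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# export the checkpoint to ONNX on the fly and run it with ONNX Runtime
ort_model = ORTModelForFeatureExtraction.from_pretrained(model_name, export=True)

onnx_extractor = pipeline("feature-extraction", model=ort_model, tokenizer=tokenizer)
# same output layout as before: [0][0] is the [CLS] token embedding
cls_vector = onnx_extractor("Patient with history of Afib with RVR.")[0][0]

The pipeline call stays drop-in compatible with your current code, so the rest of the processing can remain unchanged.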

This could also help.