Speed expectations for production BERT models on CPU vs GPU?

Hi! I’m working with an ML model produced by a researcher. I’m trying to set it up to run economically in production on large volumes of text. I know a lot about production engineering, but next to nothing about ML. I’m getting some results that are surprising to me, and I’m hoping for pointers to explanations and advice.

Right now I have the key code broken out into three methods, so it’s easy to profile.

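```python
def _tokenize(self, text):
    # Encode one string as PyTorch tensors and move them to the target device
    return self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt').to(self.device)

def _run_model(self, model_input):
    # Forward pass; element [0] of the model output is the logits tensor
    return self.model(model_input['input_ids'], token_type_ids=model_input['token_type_ids'])[0]

def _extract_results(self, logits):
    # .item() copies each scalar logit back to the host as a Python float
    return logits[0][0].item(), logits[0][1].item()
```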

If I run this on my laptop CPU, I get numbers that make sense to me. For 1459 items, those three methods take 196.4 seconds, or about 135 ms per item. Of that, about 2.9 seconds is spent in _tokenize and the rest in _run_model.
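
To keep per-stage numbers like these comparable between CPU and GPU runs, a small accumulating timer can wrap each call. This is a minimal sketch: `StageTimer` and its `sync` hook are hypothetical names, not part of PyTorch or transformers, and the loop uses stand-ins instead of the real methods. On a GPU you would pass `sync=torch.cuda.synchronize` so that each stage is charged for its own device work; on the CPU, leave it as `None`.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulate wall-clock seconds per named stage (hypothetical helper)."""

    def __init__(self, sync=None):
        # Optional zero-argument callable that blocks until queued device work
        # is finished, e.g. torch.cuda.synchronize on a GPU; None on the CPU.
        self.sync = sync
        self.totals = defaultdict(float)

    @contextmanager
    def stage(self, name):
        if self.sync is not None:
            self.sync()  # don't charge earlier queued work to this stage
        start = time.perf_counter()
        try:
            yield
        finally:
            if self.sync is not None:
                self.sync()  # wait for this stage's own device work
            self.totals[name] += time.perf_counter() - start

    def per_item_ms(self, name, n_items):
        return 1000.0 * self.totals[name] / n_items


# Usage with stand-ins for the three methods in the question.
timer = StageTimer()
texts = ["first example", "second example"]
for text in texts:
    with timer.stage("tokenize"):
        tokens = text.split()  # stand-in for self._tokenize(text)
    with timer.stage("run_model"):
        time.sleep(0.01)  # stand-in for self._run_model(tokens)
for name in timer.totals:
    print(f"{name}: {timer.per_item_ms(name, len(texts)):.1f} ms/item")
```

As a sanity check on the CPU figures, 196.4 seconds over 1459 items works out to about 134.6 ms per item.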

When I switch over to my laptop GPU, I get numbers that mystify me. The same data takes 131.8 seconds: 2.5 seconds to tokenize and 20.3 seconds to run the model. But extracting the results takes 108.8 seconds!

The _extract_results method costs the same whether I extract one logit or both. The first one that I extract is slow, whether that’s [0] or [1]. The second one is effectively free.
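
A likely explanation, worth confirming on the real model: PyTorch launches CUDA work asynchronously, so `_run_model` can return before the GPU has finished, and the first `.item()` has to wait for all of the queued work before it can copy a value back to the CPU. The second `.item()` finds the result already computed. The sketch below is a pure-Python analogy, not PyTorch code: `ToyStream` is a hypothetical stand-in for a CUDA stream, with one background worker running queued work in order.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyStream:
    """Toy analogy for a CUDA stream (not PyTorch code): queued work runs in
    order on one background worker, and launching does not wait for it."""

    def __init__(self):
        self._worker = ThreadPoolExecutor(max_workers=1)

    def launch(self, seconds, value):
        def kernel():
            time.sleep(seconds)  # simulated device compute time
            return value

        return self._worker.submit(kernel)  # returns immediately

    def close(self):
        self._worker.shutdown(wait=True)


def timed(fn):
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


stream = ToyStream()
# Like _run_model: the call returns before the "GPU" has done the work.
handle, t_launch = timed(lambda: stream.launch(0.3, (1.5, -0.7)))
# Like the first .item(): it must wait for everything queued ahead of it.
first, t_first = timed(lambda: handle.result()[0])
# Like the second .item(): the result is already available.
second, t_second = timed(lambda: handle.result()[1])
stream.close()
print(f"launch {t_launch:.3f}s, first read {t_first:.3f}s, second read {t_second:.3f}s")
```

If this is what is happening, most of the 108.8 seconds charged to `_extract_results` is model compute that simply finished late, so the 131.8-second GPU total is the figure to compare against the CPU run.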

From nvidia-smi, I can see that the GPU is really being used, and my process is using ~850 MB of the 2 GB of GPU RAM. So that seems fine. And if it matters, this is a GeForce 940MX on a ThinkPad T470.

Do these numbers make sense to more experienced hands? I was expecting the GPU runs to be much faster, but if I actually want to get the results, it’s only about 1.5 times faster overall (131.8 seconds vs. 196.4 seconds).


(probably too late to be useful…)

Is the GPU being used for all of the sections? (Maybe your researcher only put the middle section on GPU.)
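
One quick way to check is to ask where the tensors actually live. This is a minimal sketch: `report_devices` is a hypothetical helper, and the dummy classes exist only so the example runs without a GPU. It assumes `model` has a `.parameters()` method (as a `torch.nn.Module` does) and that `model_input` is dict-like (as the tokenizer's output is), with every tensor exposing a `.device` attribute.

```python
from collections import Counter


def report_devices(model, model_input):
    """Count where the model weights live and list where each input lives.

    Hypothetical helper; relies only on .parameters(), .items() and .device.
    """
    weights = Counter(str(p.device) for p in model.parameters())
    inputs = {name: str(t.device) for name, t in model_input.items()}
    return weights, inputs


# Usage with dummy objects standing in for real tensors and a real model.
class FakeTensor:
    def __init__(self, device):
        self.device = device


class FakeModel:
    def parameters(self):
        return [FakeTensor("cuda:0"), FakeTensor("cuda:0")]


weights, inputs = report_devices(
    FakeModel(),
    {"input_ids": FakeTensor("cuda:0"), "token_type_ids": FakeTensor("cpu")},
)
print(weights)  # Counter({'cuda:0': 2})
print(inputs)
```

With the real `self.model` and the output of `_tokenize`, any `cpu` entry shows a piece that is not on the GPU.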

Does the run-model step actually train and update the model, or does it just calculate answers for the data you entered?