I am using BERT for next-sentence prediction on a CPU. I want to call the model twice in a row to select a sentence from a list of sentence pairs. I am using a batch size of 128. My code looks something like this:
```python
import torch

def bert_batch_compare(self, prompt1, prompt2):
    encoding = self.tokenizer(prompt1, prompt2, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True)
    # NSP labels have shape (batch_size,): 0 = prompt2 follows prompt1, 1 = random.
    # They only affect the loss, which is not needed for inference.
    target = torch.ones(len(prompt1), dtype=torch.long)
    # no_grad() skips storing activations for backprop, which lowers peak memory.
    with torch.no_grad():
        outputs = self.model(**encoding, next_sentence_label=target)
    logits = outputs.logits.detach()
    return logits

def call_bert(self):
    batch_pattern = []
    batch_template = []
    batch_input = []
    ## make batch_pattern, batch_input, batch_template here!
    si = self.bert_batch_compare(batch_pattern, batch_input)
    ## second call to this fn is killed because of memory limitations
    sj = self.bert_batch_compare(batch_input, batch_template)
```
If I lower the batch size to something like 24, it runs, but I’d like to use a larger batch size. I am not doing any training right now. I’m using ‘bert-base-uncased’. During the second call to ‘bert_batch_compare()’, memory usage climbs to 100% and the program crashes. I have 16 GB to work with; until that point the code uses only 1.8 GB. I am using Linux and Python 3.6, along with PyTorch 1.8.
This might not be a memory leak but a case of longer padding in the second batch.
If the longest sequence in the first batch is 80 tokens, that batch will (likely) be padded to 80 tokens, and that may fit into memory. But if the longest sequence in the next batch contains 250 tokens, the whole batch is padded to 250, and that might not fit. So check the length of each individual sample to be sure.
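A small pure-Python sketch of the effect described above: `padding=True` pads every sequence to the longest one in its batch, so a single long sample sets the size of the whole batch. The token IDs and the pad ID 0 are made up for illustration, and `pad_to_longest` is a hypothetical helper, not a library function. With the real tokenizer, `encoding['input_ids'].shape` gives the padded batch size and `encoding['attention_mask'].sum(dim=1)` gives each sample's unpadded length.

```python
def pad_to_longest(batch, pad_id=0):
    """Pad every token-ID list to the length of the longest one, as padding=True does."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]

# Hypothetical token IDs: two short samples, then the same two plus one 250-token outlier.
short_batch = [[101, 7592, 102], [101, 2088, 2003, 102]]
mixed_batch = short_batch + [[101] + [1000] * 248 + [102]]

for name, batch in [("short", short_batch), ("mixed", mixed_batch)]:
    lengths = [len(seq) for seq in batch]  # per-sample lengths before padding
    padded = pad_to_longest(batch)
    # Rows x columns of the padded batch is what the model actually processes.
    print(name, lengths, len(padded), "x", len(padded[0]))
# short [3, 4] 2 x 4
# mixed [3, 4, 250] 3 x 250
```

Adding one 250-token sample turns a 2 x 4 batch into a 3 x 250 batch, even though the two short samples did not change.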
Before calling the tokenizer, I limit the strings in the batches to 80 characters. That works, or seems to, but I have one question: is 80 characters the limit, or was that determined by experiment? Am I right to limit the string length, or should I be limiting the number of tokens?
80 was just an example. You should look at your dataset to find out how long your sentences are and where the information typically is. If you have very long sentences but most of the information is at the end, you need to truncate the front, etc. Input data analysis is key. Once you have determined the best length, you can pass it to the tokenizer with max_length:
```python
self.tokenizer(prompt1, prompt2, return_tensors='pt', padding=True, truncation=True, max_length=128)
```
Note that max_length counts subword tokens (including the special tokens [CLS] and [SEP]), not whole words.
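To make the subword point concrete, here is a toy version of the greedy longest-match-first splitting that BERT's WordPiece tokenizer uses, with a tiny made-up vocabulary. The real `bert-base-uncased` vocabulary has about 30,000 entries and the real tokenizer also lowercases, strips accents, and splits punctuation, so actual splits will differ. `wordpiece` and `pair_length` are hypothetical helpers; `pair_length` also adds the three special tokens a sentence pair gets ([CLS] A [SEP] B [SEP]).

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into subword pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word becomes [UNK]
        pieces.append(piece)
        start = end
    return pieces

def pair_length(sent_a, sent_b, vocab):
    """Token count for a pair as BERT sees it: [CLS] A [SEP] B [SEP] (whitespace split only)."""
    tokens_a = [p for w in sent_a.lower().split() for p in wordpiece(w, vocab)]
    tokens_b = [p for w in sent_b.lower().split() for p in wordpiece(w, vocab)]
    return len(tokens_a) + len(tokens_b) + 3

# Hypothetical mini-vocabulary for illustration only.
VOCAB = {"the", "long", "token", "##izer", "##s", "em", "##bed", "##ding"}

print(wordpiece("embeddings", VOCAB))  # ['em', '##bed', '##ding', '##s']
print(wordpiece("tokenizers", VOCAB))  # ['token', '##izer', '##s']
print(pair_length("the tokenizers", "long embeddings", VOCAB))  # 12 (4 words)
```

Four words become twelve positions here, which is why a character or word limit is only a loose proxy for what max_length actually caps.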
You should first determine this max length according to your data, and then adjust your batch size accordingly.
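The two steps above (pick max_length from the data, then size the batch) can be sketched in plain Python. This is a sketch under stated assumptions: the token counts are made up, the 95% and 85% coverage targets are arbitrary, and the token budget assumes, hypothetically, that 24 pairs padded to 512 tokens fit in memory. Treating memory as proportional to batch_size × max_length is only a rough heuristic, since self-attention scores grow with the square of the sequence length, so the budget should be calibrated on your own machine. The cap of 512 is the maximum sequence length of `bert-base-uncased`. All function names are hypothetical.

```python
import math

def choose_max_length(token_counts, coverage=0.95, hard_limit=512):
    """Smallest length that leaves at least `coverage` of samples untruncated, capped at 512."""
    ordered = sorted(token_counts)
    # Nearest-rank percentile: the sample at position ceil(coverage * n).
    rank = max(1, math.ceil(coverage * len(ordered)))
    return min(ordered[rank - 1], hard_limit)

def batch_size_for(max_length, token_budget):
    """Largest batch size whose padded size (batch_size * max_length) fits the budget."""
    return max(1, token_budget // max_length)

def chunked(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

# Hypothetical per-pair token counts (e.g. len(tokenizer(a, b)["input_ids"]) for each pair).
counts = [12, 18, 25, 31, 40, 44, 52, 60, 75, 240]
# Hypothetical budget: assume 24 pairs padded to 512 tokens were found to fit in memory.
TOKEN_BUDGET = 24 * 512

pairs = [("sentence A %d" % i, "sentence B %d" % i) for i in range(128)]  # placeholder pairs
for coverage in (0.95, 0.85):
    max_len = choose_max_length(counts, coverage)
    size = batch_size_for(max_len, TOKEN_BUDGET)
    print(coverage, max_len, size, [len(c) for c in chunked(pairs, size)])
# 0.95 240 51 [51, 51, 26]
# 0.85 75 163 [128]
```

With these made-up counts, a single 240-token outlier pulls the 95% target up to 240, while 85% coverage gives 75 and lets all 128 pairs run in one batch; whether truncating that outlier loses important information is the dataset question raised above. In `call_bert`, each chunk's slices of the prompt lists could be passed to `bert_batch_compare` in turn and the per-chunk logits joined with `torch.cat`.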
Thank you, that’s very clear.