I'm running inference with a simple BERT model on a fixed sample of 1,200 datapoints for a classification task.
Below are the total times taken to process the entire dataset under different experiment scenarios.
Note: each time listed is the sum of the per-batch times.
**A. Using DataLoader with pinned memory true (b=32)**

- Tokenization Time: 0:00:00.983126
- CPU→GPU Data Transfer Time: 0:00:00.338249
- Forward Pass Time: 0:00:01.528260

**B. Using DataLoader with pinned memory true (b=256)**
Total Time Taken with batch_size 256: 0:00:02.825832

- Tokenization Time: 0:00:01.069770
- CPU→GPU Data Transfer Time: 0:00:01.219674
- Forward Pass Time: 0:00:00.464684

**C. Using DataLoader with pinned memory true (b=1024)**
Total Time Taken with batch_size 1024: 0:00:03.747645

- Tokenization Time: 0:00:00.930315
- CPU→GPU Data Transfer Time: 0:00:02.331700
- Forward Pass Time: 0:00:00.069030
What I don't understand is why the total data transfer time (`inputs.to(device)`) is higher for a larger batch size, i.e., why does transferring one batch of 1024 samples take longer than transferring a batch of 32 samples 32 times?
How do I solve this CPU bottleneck issue?
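To sanity-check the measurement itself, I'm thinking of timing just the host→device copy in isolation, outside the inference loop. This is a minimal sketch of my own (not part of the setup below): it uses synthetic int64 tensors as a stand-in for the tokenizer output, and calls `torch.cuda.synchronize()` before stopping the clock since GPU work can still be in flight when the Python call returns:

```
# Standalone sketch: copy the same total payload as 32 batches of 32 rows
# vs. one batch of 1024 rows. The random int64 tensors are an assumption
# standing in for tokenized input IDs (seq_len=128, as in my setup);
# synchronize() ensures the wall clock includes any pending GPU work.
import torch
from datetime import datetime

device = torch.device("cuda")

def time_copies(batch_size, n_batches, seq_len=128):
    batches = [torch.randint(0, 30000, (batch_size, seq_len)) for _ in range(n_batches)]
    torch.cuda.synchronize()
    start = datetime.now()
    for b in batches:
        b.to(device)
    torch.cuda.synchronize()
    return datetime.now() - start

print("32 copies of b=32 :", time_copies(32, 32))
print("1 copy of b=1024  :", time_copies(1024, 1))
```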
Some info on the inference setup code:
```
from datetime import datetime, timedelta

import torch
from tqdm import tqdm

tokenization_time, datatransfer_time, inference_time = [], [], []
outputs = []
results = {}

for batch_inputs in tqdm(inference_loader):
    # batch_inputs is a list of raw strings yielded by the DataLoader
    s = datetime.now()
    inputs = tokenizer(batch_inputs, return_tensors="pt", padding=True,
                       truncation=True, max_length=128)
    e = datetime.now()
    tokenization_time.append(e - s)

    s = datetime.now()
    inputs = inputs.to(device)
    e = datetime.now()
    datatransfer_time.append(e - s)

    s = datetime.now()
    logits = model(**inputs).logits
    e = datetime.now()
    inference_time.append(e - s)

    outputs.append(logits)

outputs = torch.cat(outputs, dim=0)
print(outputs.shape)
results['labels'] = [id2label[indx] for indx in outputs.argmax(dim=-1).tolist()]
results['scores'] = outputs.softmax(dim=-1).tolist()

print("Tokenization Time: ", sum(tokenization_time, timedelta(0)))
print("Data Transfer Time: ", sum(datatransfer_time, timedelta(0)))
print("Forward Pass Time: ", sum(inference_time, timedelta(0)))
```
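One detail I'm unsure about: since I tokenize inside the loop, the tensors the tokenizer produces are ordinary pageable tensors; as far as I understand, `pin_memory=True` on the DataLoader only pins tensors the loader itself returns, and mine yields raw strings. A minimal sketch of what explicitly pinning the tokenized tensors plus an asynchronous copy might look like for the transfer step (`pin_memory()` and `non_blocking=True` are standard PyTorch tensor APIs; the variable names match my loop above):

```
# Sketch: pin the tokenized tensors by hand, then issue the copy with
# non_blocking=True so it can overlap with subsequent host work.
# Assumes a CUDA `device` and the `tokenizer` / `batch_inputs` / `model`
# from the loop above; the tokenizer output behaves like a dict of tensors.
inputs = tokenizer(batch_inputs, return_tensors="pt",
                   padding=True, truncation=True, max_length=128)
inputs = {k: v.pin_memory() for k, v in inputs.items()}                   # page-locked host memory
inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}  # async H2D copy
logits = model(**inputs).logits
```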