I’m looking to summarize a dictionary of (key, value) pairs. However, running sequential CPU summarization has been almost 2x as fast as running batched GPU summarization. How should I speed up the GPU summarization?
Tech Specs:
- MacBook Pro M3 with 8GB of memory
- PyTorch “MPS” backend for the computer’s GPU
- macOS Sonoma 14.6.1
- Python3.11
- transformers==4.44.2
Code:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
class GPUTextSummarizer:
def __init__(self, input_dict):
model_name = "sshleifer/distilbart-cnn-12-6"
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
self.summarizer = pipeline("summarization", model=self.model, tokenizer=self.tokenizer, framework="pt", device=device)
self.summarized_dict = {}
self.input_dict = input_dict
def get_summarized_description(self, keys, batch_size=32, num_workers=4):
values = [self.input_dict[key] for key in keys]
num_tokens_list = [len(self.tokenizer.encode(value)) for value in values]
descriptions_to_summarize = [(key, value) for key, value, num_tokens in zip(keys, values, num_tokens_list) if num_tokens >= 75]
def summarize_batch(batch):
keys, values = zip(*batch)
values = list(values)
print(f"\nSummarizing batch: {', '.join(keys)}")
summarized_values = self.summarizer(
values,
max_length=150,
min_length=75,
truncation=True,
do_sample=False,
)
return list(zip(keys, summarized_values))
batches = [descriptions_to_summarize[i:i + batch_size] for i in range(0, len(descriptions_to_summarize), batch_size)]
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(summarize_batch, batch) for batch in batches]
for future in tqdm(as_completed(futures), total=len(futures)):
for key, summary in future.result():
self.summarized_dict[key] = summary['summary_text']
for key, value in zip(keys, values):
if key not in self.summarized_dict:
self.summarized_dict[key] = value
return self.summarized_dict