CPU faster than MacBook GPU for Summarization

I’m summarizing the text values in a dictionary of (key, value) pairs. However, sequential summarization on the CPU has been almost 2x as fast as batched summarization on the GPU. How can I speed up the GPU summarization?
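
For reference, the sequential CPU baseline I’m comparing against looks roughly like this (a simplified sketch, not the exact code I ran): one pipeline call per text, no batching, running on the CPU.

from transformers import pipeline

# Simplified CPU baseline sketch: summarize each text one at a time on the CPU.
cpu_summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
    framework="pt",
    device=-1,  # CPU
)

def summarize_on_cpu(input_dict):
    summaries = {}
    for key, text in input_dict.items():
        result = cpu_summarizer(
            text, max_length=150, min_length=75, truncation=True, do_sample=False
        )
        summaries[key] = result[0]["summary_text"]
    return summaries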

Tech Specs:

  • MacBook Pro M3 with 8GB of memory
  • PyTorch “MPS” backend for the computer’s GPU
  • macOS Sonoma 14.6.1
  • Python 3.11
  • transformers==4.44.2

Code:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

class GPUTextSummarizer:
    def __init__(self, input_dict):
        model_name = "sshleifer/distilbart-cnn-12-6"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Use the Apple-silicon GPU via MPS when available; otherwise fall back to the CPU.
        device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
        self.summarizer = pipeline("summarization", model=self.model, tokenizer=self.tokenizer, framework="pt", device=device)
        self.summarized_dict = {}
        self.input_dict = input_dict

    def get_summarized_description(self, keys, batch_size=32, num_workers=4):
        values = [self.input_dict[key] for key in keys]
        num_tokens_list = [len(self.tokenizer.encode(value)) for value in values]

        # Only summarize entries that are at least 75 tokens long; shorter ones are copied through unchanged at the end.
        descriptions_to_summarize = [(key, value) for key, value, num_tokens in zip(keys, values, num_tokens_list) if num_tokens >= 75]

        def summarize_batch(batch):
            # Use batch-local names so the outer `keys`/`values` are not shadowed.
            batch_keys, batch_values = zip(*batch)
            batch_values = list(batch_values)

            print(f"\nSummarizing batch: {', '.join(batch_keys)}")
            summarized_values = self.summarizer(
                batch_values,
                max_length=150,
                min_length=75,
                truncation=True,
                do_sample=False,
            )
            return list(zip(batch_keys, summarized_values))

        batches = [descriptions_to_summarize[i:i + batch_size] for i in range(0, len(descriptions_to_summarize), batch_size)]

        # Each worker thread submits one batch to the single shared MPS pipeline.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(summarize_batch, batch) for batch in batches]
            for future in tqdm(as_completed(futures), total=len(futures)):
                for key, summary in future.result():
                    self.summarized_dict[key] = summary['summary_text']

        # Entries that were too short to summarize keep their original text.
        for key, value in zip(keys, values):
            if key not in self.summarized_dict:
                self.summarized_dict[key] = value

        return self.summarized_dict
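
For completeness, this is roughly how I invoke the class (the dictionary contents below are placeholders, not my real data):

if __name__ == "__main__":
    # Placeholder inputs; the real dictionary maps ids to long text descriptions.
    input_dict = {
        "doc_1": "First long description that needs summarizing. " * 40,
        "doc_2": "Second long description that needs summarizing. " * 40,
    }
    summarizer = GPUTextSummarizer(input_dict)
    summaries = summarizer.get_summarized_description(list(input_dict.keys()))
    for key, summary in summaries.items():
        print(key, summary[:80])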