Beam_search bottlenecks inference, using only 1 CPU

Hello there, :slight_smile:

I am having trouble optimizing my inference translation pipeline. The bottleneck seems to be beam_search, which uses only 1 CPU even though 1 GPU and 16 CPUs are available.

Here is an overview of CPU usage, captured with py-spy:

I have looked at the code to try to understand what is happening, without success.

In case you need them, here are the frames of the problematic code:

Thread 188015 (active): "MainThread"
    process (transformers/
    beam_search (transformers/
    generate (transformers/
    decorate_context (torch/autograd/
    infer_dataset (
    _CallAndUpdateTrace (fire/
    _Fire (fire/
    Fire (fire/
    main (
    <module> (
    _run_code (
    _run_module_as_main (

What seems strange is that I thought beam_search used the GPU to be fast (we can see the device=device in the code). I don't know why the CPU is used here, nor how to make it use either the GPU or all the available CPUs.
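To check the device question in isolation, here is a minimal sketch (pure PyTorch, no transformers involved): tensors coming out of a DataLoader live on the CPU until they are explicitly moved, so generate() only runs its forward passes on the GPU if the inputs were moved there first.

```python
import torch

# Pick the GPU if one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a batch of input_ids produced by a DataLoader:
# it is created on the CPU, not on `device`.
batch = torch.randint(0, 1000, (2, 16))
print(batch.device)  # cpu

# Tensors must be moved explicitly before being passed to model.generate.
batch = batch.to(device)
```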

Here is a small reproducible example:


import pandas as pd
import torch
import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import MarianMTModel, MarianTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TranslationDataset(Dataset):
    def __init__(self, dataset_path, tokenizer):
        self.tokenizer = tokenizer
        self.dataset = pd.read_csv(dataset_path)["text"].values

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        x = self.tokenizer(self.dataset[idx], return_tensors="pt", max_length=512, padding='max_length')
        return x['input_ids'][0], x['attention_mask'][0]

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-mul-en")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-mul-en").to(device)
dataset_class = TranslationDataset("data.csv", tokenizer)

for input_ids, attention_mask in tqdm.tqdm(DataLoader(dataset_class, batch_size=8)):  # batch_size is an example value
    tokenized_outputs = model.generate(
        input_ids.to(device),
        attention_mask=attention_mask.to(device),
        num_beams=4,  # num_beams > 1, exact value not important
    )
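One part I can at least parallelize across CPUs is the tokenization itself: with num_workers > 0, the DataLoader runs __getitem__ in separate worker processes. A minimal sketch with a dummy dataset (the class name and the values here are placeholders, not my real code):

```python
from torch.utils.data import DataLoader, Dataset

class DummyDataset(Dataset):
    """Stand-in for TranslationDataset: __getitem__ does the per-item work."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx * 2  # placeholder for the real tokenization

# num_workers > 0 runs __getitem__ in worker processes, so the per-item
# work happens in parallel while the main process stays free for generate().
loader = DataLoader(DummyDataset(16), batch_size=4, num_workers=2)
batches = list(loader)
print(len(batches))  # 4 batches of 4 items each
```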

Could you help me figure out why beam_search is the bottleneck of the pipeline, and how to make it run on multiple CPUs or on the GPU, please? :slight_smile:

Relevant info:

  • num_beams > 1
  • num_beam_groups = 1
  • do_sample = False
  • is_constraint_gen_mode = False

Thanks in advance :smiley:

It seems like I am not the only one facing this problem:

Any ideas for a solution? :slight_smile: