model.generate() is extremely slow when using beam search

I was previously using trainer.predict() for inference, but it does not offer much flexibility or control over generation. For instance, I want more than one generation per sample, which is not possible with trainer.predict(). So I switched to model.generate(). The problem is that I now have to handle batching on my own and cannot really use multiple GPUs. Still, I feel it should not be this slow even under these circumstances. Below is a small code snippet that does the main job.

My batch size is 128 and target_max_length is 256. I have 4 brand new A100 GPUs but can only use one because I do not know how to parallelize the inference; trainer.predict() did that automatically. My beam size is 5, and I also return 5 sequences per sample. With trainer.predict(), also with beam size 5, inference was pretty fast, so I must be doing something wrong.
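
To put these numbers in perspective, here is a rough back-of-the-envelope sketch of the workload for one generate() call. It assumes that beam search decodes num_beams hypotheses per input in parallel and runs for at most max_length decoding steps. The variable names are illustrative and the figures are upper bounds, not measurements.

```python
# Rough workload estimate for one generate() call (illustrative names).
batch_size = 128          # inputs per batch
num_beams = 5             # beam size
num_return_sequences = 5  # sequences kept per input
max_length = 256          # upper bound on decoding steps

# Beam search decodes batch_size * num_beams hypotheses in parallel.
hypotheses_per_step = batch_size * num_beams

# Upper bound on decoder token positions processed per batch
# (hypotheses times steps); real runs can stop earlier.
max_decoder_positions = hypotheses_per_step * max_length

# Rows returned by generate() for one batch.
rows_returned = batch_size * num_return_sequences

print(hypotheses_per_step)    # 640
print(max_decoder_positions)  # 163840
print(rows_returned)          # 640
```

So each batch of 128 inputs behaves like a batch of 640 sequences on the decoder side, all of it on a single GPU.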

CC: @sgugger @valhalla

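import numpy as np

# tokenizer, model, args, num_predictions, test_inputs and test_labels are defined earlier.
for i, warning in enumerate(all_warning_types):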
    test_warning_inputs = test_inputs[warning]
    test_warning_labels = test_labels[warning]
    target_max_length = 256 
    print(f"number of samples for rule: {warning} is {len(test_warning_inputs)}")

    input_ids = tokenizer(
        test_warning_inputs, 
        return_tensors="pt", 
        truncation=True, 
        padding="max_length", 
        max_length=target_max_length
    ).input_ids
    target_ids = tokenizer(
        test_warning_labels, 
        return_tensors="pt", 
        truncation=True, 
        padding="max_length", 
        max_length=target_max_length
    ).input_ids
    target_ids = target_ids.numpy()

    # Integer buffer for the generated token ids (np.zeros defaults to float64).
    all_beam_output_ids = np.zeros(
        (len(test_warning_inputs) * num_predictions, target_max_length),
        dtype=np.int64,
    )
    for batch_start_idx in range(0, len(test_warning_inputs), args.batch_size):
        batch_input_ids = input_ids[batch_start_idx:(batch_start_idx+args.batch_size)]
        batch_input_ids = batch_input_ids.to(model.device)
        batch_beam_output_ids = model.generate(
            batch_input_ids, 
            max_length=target_max_length, 
            num_beams=args.beam_size, 
            early_stopping=False, 
            num_return_sequences=num_predictions
        )
        
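        # Move the generated ids to the CPU and convert to NumPy before padding.
        batch_beam_output_ids = batch_beam_output_ids.cpu().numpy()
        # Right-pad each sequence to target_max_length so every batch has the same width.
        # mode="constant" pads with 0; pass constant_values=tokenizer.pad_token_id if the pad id is not 0.
        pad_width = target_max_length - batch_beam_output_ids.shape[1]
        batch_beam_output_ids = np.pad(
            batch_beam_output_ids,
            ((0, 0), (0, pad_width)),
            mode="constant"
        )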

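        # Use the actual number of returned rows so the last, smaller batch is handled explicitly
        # instead of relying on NumPy clipping the slice at the end of the buffer.
        batch_beginning_idx = batch_start_idx * num_predictions
        batch_ending_idx = batch_beginning_idx + batch_beam_output_ids.shape[0]
        all_beam_output_ids[batch_beginning_idx:batch_ending_idx, :] = batch_beam_output_ids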
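
The pad-and-store bookkeeping at the end of the loop can be checked in isolation, without a model. The sketch below is a NumPy-only version, assuming each batch comes back as a 2-D array of token ids. pad_to_width, collect_outputs and fake_batches are hypothetical names used only for illustration, and pad_id=0 is an assumed pad id.

```python
import numpy as np


def pad_to_width(ids: np.ndarray, width: int, pad_id: int = 0) -> np.ndarray:
    """Right-pad a 2-D array of token ids to `width` columns with `pad_id`."""
    missing = width - ids.shape[1]
    if missing < 0:
        raise ValueError("batch is wider than the target width")
    return np.pad(ids, ((0, 0), (0, missing)), mode="constant", constant_values=pad_id)


def collect_outputs(batches, num_samples, num_predictions, width, pad_id=0):
    """Copy per-batch outputs into one (num_samples * num_predictions, width) buffer."""
    out = np.full((num_samples * num_predictions, width), pad_id, dtype=np.int64)
    row = 0
    for batch in batches:
        padded = pad_to_width(np.asarray(batch, dtype=np.int64), width, pad_id)
        out[row:row + padded.shape[0]] = padded  # also correct for a short last batch
        row += padded.shape[0]
    return out


# Toy stand-in for generate() outputs: 3 samples, 2 predictions each,
# batch size 2, so the second batch holds only one sample (2 rows).
fake_batches = [
    np.array([[5, 6, 7], [5, 6, 8], [9, 1, 0], [9, 2, 0]]),  # 2 samples x 2 predictions
    np.array([[4, 3], [4, 2]]),                              # 1 sample x 2 predictions
]
result = collect_outputs(fake_batches, num_samples=3, num_predictions=2, width=4)
print(result.shape)  # (6, 4)
```

Tracking the write position with a running row counter avoids computing the end index from args.batch_size, which overshoots on the last batch whenever the number of samples is not a multiple of the batch size.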