I was previously using trainer.predict() to run my inference, but it does not offer much flexibility or control over generation. For instance, I want more than one generated sequence for each sample, which is not possible with trainer.predict(), so I switched to model.generate(). The problem is that I now have to handle batching on my own and cannot really use multiple GPUs. Still, I feel it should not be this slow even under these circumstances. Below is a small code snippet that does the main job.
My batch size is 128 and target_max_length is 256. I have 4 brand-new A100 GPUs but can only use one, because I do not know how to parallelize the inference; trainer.predict() did that automatically. My beam size is 5, and I also return 5 sequences per sample. When I used trainer.predict(), also with beam size 5, it was pretty fast, so there has to be something I am doing wrong.
# tokenizer, model, args, num_predictions, all_warning_types,
# test_inputs and test_labels are defined earlier in the script.
import numpy as np

for i, warning in enumerate(all_warning_types):
    test_warning_inputs = test_inputs[warning]
    test_warning_labels = test_labels[warning]
    target_max_length = 256
    print(f"number of samples for rule: {warning} is {len(test_warning_inputs)}")

    # Tokenize inputs and labels, padded to the full target length
    input_ids = tokenizer(
        test_warning_inputs,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=target_max_length,
    ).input_ids
    target_ids = tokenizer(
        test_warning_labels,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=target_max_length,
    ).input_ids
    target_ids = np.array(target_ids)

    # One row per returned sequence: num_predictions consecutive rows per input sample
    all_beam_output_ids = np.zeros((len(test_warning_inputs) * num_predictions, target_max_length))

    for batch_start_idx in range(0, len(test_warning_inputs), args.batch_size):
        batch_input_ids = input_ids[batch_start_idx:batch_start_idx + args.batch_size]
        batch_input_ids = batch_input_ids.to(model.device)

        # Beam search, returning num_predictions sequences per input
        batch_beam_output_ids = model.generate(
            batch_input_ids,
            max_length=target_max_length,
            num_beams=args.beam_size,
            early_stopping=False,
            num_return_sequences=num_predictions,
        )
        batch_beam_output_ids = batch_beam_output_ids.cpu()

        # Right-pad the generated ids so every row has length target_max_length
        batch_beam_output_ids = np.pad(
            batch_beam_output_ids,
            ((0, 0), (0, target_max_length - batch_beam_output_ids.shape[1])),
            mode="constant",
        )

        # Write this batch into the flat output array
        batch_beginning_idx = batch_start_idx * num_predictions
        batch_ending_idx = batch_beginning_idx + args.batch_size * num_predictions
        all_beam_output_ids[batch_beginning_idx:batch_ending_idx, :] = batch_beam_output_ids
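Since generate() returns num_predictions consecutive rows per input, the candidates for sample i end up in rows i * num_predictions to (i + 1) * num_predictions of all_beam_output_ids. Just to show that layout, here is a minimal decoding sketch (it would sit at the end of each warning's iteration; the variable names decoded, candidates and source are only for illustration):

    # Sketch: decode the flat output array back into num_predictions candidates per sample.
    decoded = tokenizer.batch_decode(
        all_beam_output_ids.astype(int),  # np.zeros() made the array float, cast back to token ids
        skip_special_tokens=True,
    )
    for i, source in enumerate(test_warning_inputs):
        candidates = decoded[i * num_predictions:(i + 1) * num_predictions]
        # candidates should be ordered best-first by beam score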