How to parallelize model.generate?

Hi,

I am using model.generate() with a T5 model. I have a for loop where I iterate over the samples one by one and, as expected, it is too slow. I know the Trainer class has a trainer.predict() method, but then I cannot apply beam decoding, right?

I also know that I can pass batches to the tokenizer and to model.generate() to speed up inference, but then I will still need a for loop that iterates over the batches, right?

Is there a way to parallelize generation while still using beam search and the other decoding strategies that model.generate() provides?
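
For context, here is roughly what my batched loop would look like. This is just a minimal sketch; the checkpoint name, example inputs, batch size, and generation parameters are placeholders:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
model.eval()

inputs = ["translate English to German: Hello world."] * 32  # placeholder data
batch_size = 8
outputs = []
with torch.no_grad():
    for i in range(0, len(inputs), batch_size):
        batch = inputs[i : i + batch_size]
        enc = tokenizer(
            batch, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        # Beam search runs over the whole batch at once; only this outer
        # loop over batches remains.
        generated = model.generate(**enc, num_beams=4, max_length=64)
        outputs.extend(
            tokenizer.batch_decode(generated, skip_special_tokens=True)
        )
```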


Hi,

Did you find a good solution?

Thanks.