How can I parallelize a metric?

I am currently fine-tuning a language model using a policy-gradient reinforcement learning technique. Instead of a standard loss function, I am using a reward function and the REINFORCE algorithm to teach the model to emulate some desired behaviour.

As part of the reward function, I compute the ROUGE score between a reference sentence and a generated one. To do this I’m currently using a list comprehension that zips together two lists of sentences (ref_batch and pred_batch below) and then calculates the ROUGE score for each pair.

The code looks something like this:

from datasets import load_metric
import torch

rouge_metric = load_metric("rouge")

def get_rouge_score(ref, pred):
    # Compute the ROUGE-L F1 score for a single (reference, prediction) pair
    return rouge_metric.compute(rouge_types=["rougeL"], predictions=[pred], references=[ref])["rougeL"].mid.fmeasure

rouge_scores = torch.tensor(
    [get_rouge_score(ref, pred) for ref, pred in zip(ref_batch, pred_batch)],
    device=device,
)

The problem is that this is very slow. The list comprehension iterates through the examples one by one on the CPU, while the rest of the training loop runs on GPU tensor cores. This step is therefore a significant bottleneck: profiling the training step shows that it alone accounts for ~60% of the training time.

So my question: how could I parallelize this step, or even make it faster another way? If there’s a way to calculate scores on the GPU, that would be awesome. If not, is there an easy way I can use multiple CPU cores for this?

I’m also able to change the metric from ROUGE to another one that is more able to be parallelized, if that helps.

Thanks

Hi! To speed up the processing you can pass keep_in_memory=True to load_metric to keep each sample in memory (by default it writes them to disk to save memory, but since you’re passing the examples to compute one by one you don’t need this). This should speed up your computation significantly.

Moreover, feel free to use Python multiprocessing to parallelize this (using multiprocessing.Pool and Pool.imap, for example).
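A minimal sketch of that suggestion. To keep the snippet self-contained (no datasets dependency), a toy token-overlap F1 stands in for the real scorer; in your code you would call get_rouge_score(ref, pred) inside the worker instead. The "fork" start method is assumed (Unix only), which keeps the pool safe to run at module level without a __main__ guard:

```python
from multiprocessing import get_context

def score(pair):
    # Each worker receives one (ref, pred) pair. This toy scorer computes an
    # F1 over whitespace tokens; swap in get_rouge_score(ref, pred) for ROUGE.
    ref, pred = pair
    ref_tokens, pred_tokens = set(ref.split()), set(pred.split())
    overlap = len(ref_tokens & pred_tokens)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

ref_batch = ["the cat sat on the mat", "hello world"]
pred_batch = ["the cat sat", "goodbye world"]

# imap streams pairs to the workers; chunksize batches them to cut down on
# inter-process overhead, which matters when each call is cheap.
with get_context("fork").Pool(processes=4) as pool:
    scores = list(pool.imap(score, zip(ref_batch, pred_batch), chunksize=8))
```

The resulting list of floats can then be wrapped in torch.tensor(..., device=device) exactly as in the original loop. Note that process start-up has a fixed cost, so it’s worth creating the Pool once outside the training loop rather than per batch.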

Thanks! Setting keep_in_memory=True did indeed speed things up drastically. This change alone made my training loop 16% faster and my eval loop 25% faster, so that’s really cool to see.

I noticed too that load_metric has a keyword num_process - can I just set this number higher to easily do multiprocessing?

This parameter can be used if you are doing a distributed evaluation (see the docs here).

However, in your case you are running n independent evaluations of the metric, so you have to parallelize the code yourself.
