zero_division warning in metric.compute

Python: 3.9.7
Datasets: 2.1.0

I’m getting the following warning whenever I run compute() on either the recall or precision metric:

/home/aclifton/anaconda3/envs/rffp/lib/python3.9/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))

It looks like it’s coming from the metrics defined in scikit-learn (recall, for example). I tried passing zero_division=0 to compute() but got the following error:

Traceback (most recent call last):
  File "/home/aclifton/rf_fp/run_training.py", line 278, in <module>
    tmp_metric_result = rffp_data.metrics[metric].compute(average='macro', zero_division=0)
  File "/home/aclifton/anaconda3/envs/rffp/lib/python3.9/site-packages/datasets/metric.py", line 430, in compute
    output = self._compute(**inputs, **compute_kwargs)
TypeError: _compute() got an unexpected keyword argument 'zero_division'
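
From the traceback, it looks like compute() just forwards extra keyword arguments to the metric script’s _compute(), which doesn’t accept zero_division in this version. A quick way to check what a loaded metric actually accepts (a sketch using inspect):

import inspect

import datasets

# Print the keyword arguments the loaded metric script's _compute() accepts;
# per the traceback above, zero_division is not among them in datasets 2.1.0.
metric = datasets.load_metric('precision')
print(inspect.signature(metric._compute))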

I know it’s only a warning and shouldn’t affect the output of my evaluation loop, but is there a way to suppress it via the zero_division keyword, as described in the scikit-learn documentation?
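
For reference, calling scikit-learn directly does accept the parameter; a minimal standalone example (not from my actual script):

from sklearn.metrics import precision_score

# Class 1 is never predicted, so its precision is 0/0 and would normally
# trigger UndefinedMetricWarning; zero_division=0 sets it to 0.0, silently.
print(precision_score([0, 1], [0, 0], average='macro', zero_division=0))  # 0.25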

Thanks in advance for your help!

Hi! Can you provide the exact code that produces this error? I can’t reproduce it locally.

@mariosasko Sure thing. Here’s a distilled-down snippet:

import datasets
import torch

# my_model and eval_dataloader are defined earlier in the full script
my_metric = datasets.load_metric('precision')
all_preds = torch.tensor(())
preds_labels = torch.tensor(())

my_model.eval()
for batch in eval_dataloader:
    batch = {k: v for k, v in batch.items()}
    with torch.no_grad():
        outputs = my_model(**batch)

    logits = outputs['logits']
    predictions = torch.argmax(logits, dim=-1)

    # accumulate predictions and labels across batches
    all_preds = torch.cat((all_preds, predictions))
    preds_labels = torch.cat((preds_labels, batch['labels']))

    my_metric.add_batch(predictions=predictions, references=batch['labels'])

# passing zero_division here raises the TypeError above
tmp_metric_result = my_metric.compute(average='macro', zero_division=0)
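
In case it’s useful to others: a minimal stopgap, assuming the goal is just to silence this one warning category, is to filter it around the call. It hides the warning rather than forwarding zero_division:

import warnings

from sklearn.exceptions import UndefinedMetricWarning

# Suppress only sklearn's UndefinedMetricWarning for this call;
# all other warnings still surface normally.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=UndefinedMetricWarning)
    tmp_metric_result = my_metric.compute(average='macro')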