Cannot use Dataset.map on multiple GPUs during evaluation

Hi,

I am new to the Hugging Face community and am having trouble running an example evaluation script on multiple GPUs. I am using this LED model here, but the code only ever uses one GPU. I tried several approaches, such as wrapping the model with `model = torch.nn.DataParallel(model).cuda()`, but it still uses only one GPU. All my other scripts run fine on multiple GPUs, including the LED fine-tuning script here.

I am fairly comfortable with PyTorch and have tried several variants, including plain PyTorch, but none of them use more than one GPU.
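A plausible reason the `DataParallel` attempt changes nothing: `torch.nn.DataParallel` only parallelizes calls to the wrapped module's `forward()`, scattering each input batch across the GPUs. `generate()` is a separate method, so it is not exposed on the wrapper at all (it has to be reached through `.module`), and calling it there runs the whole beam search on a single device. The sketch below uses a hypothetical `ToySeq2Seq` model, not the LED model, only to illustrate that behaviour; it also runs on a CPU-only machine.

```python
import torch
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    """Hypothetical stand-in for a model that has both forward() and generate()."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.DataParallel splits x along dim 0 and calls this on each GPU replica.
        return self.proj(x)

    @torch.no_grad()
    def generate(self, x: torch.Tensor) -> torch.Tensor:
        # A custom method: DataParallel never replicates or scatters calls to it.
        return self.forward(x).argmax(dim=-1)


model = ToySeq2Seq()
wrapped = nn.DataParallel(model)

# The wrapper only parallelizes forward(); generate() is not an attribute of it.
has_generate_on_wrapper = hasattr(wrapped, "generate")

# Going through .module calls the original model directly, so all the work
# happens on whatever single device its parameters live on.
out = wrapped.module.generate(torch.randn(8, 4))
```

In other words, wrapping the model does not help `generate()`; the work has to be split across GPUs at the data level instead.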

Here is a minimal example extracted from the notebook above:

```python
import torch
# load_metric was removed in datasets 3.0; newer code uses evaluate.load("rouge")
from datasets import load_dataset, load_metric
from transformers import LEDForConditionalGeneration, LEDTokenizer

test_dataset = load_dataset("scientific_papers", "arxiv", split="test")

tokenizer = LEDTokenizer.from_pretrained("allenai/led-large-16384-arxiv")
# "cuda" resolves to the current default device (cuda:0), so the model lives on one GPU
model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv").to("cuda").half()


def generate_answer(batch):
    inputs_dict = tokenizer(batch["article"], padding="max_length", max_length=16384, return_tensors="pt", truncation=True)
    input_ids = inputs_dict.input_ids.to("cuda")
    attention_mask = inputs_dict.attention_mask.to("cuda")

    global_attention_mask = torch.zeros_like(attention_mask)
    # put global attention on <s> token
    global_attention_mask[:, 0] = 1

    predicted_abstract_ids = model.generate(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, max_length=512, num_beams=4)
    batch["predicted_abstract"] = tokenizer.batch_decode(predicted_abstract_ids, skip_special_tokens=True)
    return batch


dataset_small = test_dataset.select(range(600))
# map() calls generate_answer sequentially in this single process, two articles per call
result_small = dataset_small.map(generate_answer, batched=True, batch_size=2)

rouge = load_metric("rouge")
rouge.compute(predictions=result_small["predicted_abstract"], references=result_small["abstract"], rouge_types=["rouge2"])["rouge2"].mid
```

Hi! The datasets docs have a paragraph here on how to use `map` in a multi-GPU environment. Or have you already tried that without success?