Multi-GPU inference with Luke NER not working

Hi everyone!

I’m trying to run Luke for inference on multiple gpus using DataParallel but I’m encountering an error that I can’t seem to resolve. Can you help ?

Here is my code

from transformers import LukeTokenizer, LukeForEntitySpanClassification
import torch

luke_model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

# Getting inputs (type : transformers.tokenization_utils_base.BatchEncoding)

inputs = []
for i in tqdm(range(10)):
  input_filepath = df["input_filepath"].iloc[i]
  handle = open(input_filepath,'rb')
  input_tensor = pickle.load(handle)
  
  inputs.append(input_tensor)


device_ids = [0,1,2,3]
model= torch.nn.DataParallel(luke_model)
model.to("cuda")
replicas = nn.parallel.replicate(model,device_ids)

inputs_dp = nn.parallel.scatter(inputs[:4], device_ids)
outputs = nn.parallel.parallel_apply(replicas, inputs_dp)

The error I get is :

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<command-1863732679336681> in <module>
     21 
     22 inputs_dp = nn.parallel.scatter(inputs[:4], device_ids)
---> 23 outputs = nn.parallel.parallel_apply(replicas, inputs_dp)

/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     84         output = results[i]
     85         if isinstance(output, ExceptionWrapper):
---> 86             output.reraise()
     87         outputs.append(output)
     88     return outputs

/databricks/python/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    432             # instantiate since we don't know how to
    433             raise RuntimeError(msg) from None
--> 434         raise exception
    435 
    436 

AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/databricks/python/lib/python3.8/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/databricks/python/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 250, in __getattr__
    return self.data[item]
KeyError: 'size'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/transformers/models/luke/modeling_luke.py", line 1583, in forward
    outputs = self.luke(
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/transformers/models/luke/modeling_luke.py", line 977, in forward
    input_shape = input_ids.size()
  File "/databricks/python/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 252, in __getattr__
    raise AttributeError
AttributeError

Thanks in advance!