DETR: use torchscripted model on both cpu and gpu

Hi,

I have fine-tuned a DETR model on a custom dataset and want to test deployment options. To do this I compiled the model with torch.jit.trace.
I noticed that I have to trace the model on a cpu instance if I want to deploy it on a cpu instance, and trace it on a gpu instance if I want to deploy it on a gpu instance. Trying to deploy the cpu-traced version on a gpu leads to a RuntimeError:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
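To illustrate the kind of tracing I'm doing, here is a minimal sketch using a tiny stand-in module instead of the actual fine-tuned DETR checkpoint (the module and file name are placeholders, the flow is the same):

```python
import torch

class StandIn(torch.nn.Module):
    """Tiny placeholder for the DETR wrapper. Like DETR, it creates a tensor
    inside forward(); torch.jit.trace records the device of such tensors as a
    constant in the graph (visible as device=torch.device("cpu") in the
    serialized code above)."""

    def forward(self, pixel_values):
        # Created inside forward -> its device gets baked into the trace
        pixel_mask = torch.ones(pixel_values.shape, device=pixel_values.device)
        return pixel_values + pixel_mask

model = StandIn().eval()
example = torch.randn(1, 3, 64, 64)  # dummy pixel_values
traced = torch.jit.trace(model, example)
traced.save("traced_model.pt")
```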
The relevant stack trace is below:

2022-07-25T16:45:25,360 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - Traceback of TorchScript, serialized code (most recent call last):
2022-07-25T16:45:25,360 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -   File "code/__torch__/Detr_model.py", line 11, in forward
2022-07-25T16:45:25,360 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     pixel_values: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
2022-07-25T16:45:25,361 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     model = self.model
2022-07-25T16:45:25,361 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _0, _1, _2, _3, = (model).forward(pixel_values, )
2022-07-25T16:45:25,361 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -                        ~~~~~~~~~~~~~~ <--- HERE
2022-07-25T16:45:25,362 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     return (_0, _1, _2, _3)
2022-07-25T16:45:25,362 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -   File "code/__torch__/transformers/models/detr/modeling_detr.py", line 14, in forward
2022-07-25T16:45:25,363 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     class_labels_classifier = self.class_labels_classifier
2022-07-25T16:45:25,363 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     model = self.model
2022-07-25T16:45:25,363 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _0, _1, = (model).forward(pixel_values, )
2022-07-25T16:45:25,363 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -                ~~~~~~~~~~~~~~ <--- HERE
2022-07-25T16:45:25,364 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _2 = (class_labels_classifier).forward(_0, )
2022-07-25T16:45:25,364 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _3 = torch.sigmoid((bbox_predictor).forward(_0, ))
2022-07-25T16:45:25,364 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -   File "code/__torch__/transformers/models/detr/modeling_detr.py", line 47, in forward
2022-07-25T16:45:25,364 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     width = ops.prim.NumToTensor(torch.size(pixel_values, 3))
2022-07-25T16:45:25,365 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     pixel_mask = torch.ones([_0, _6, int(width)], dtype=6, layout=None, device=torch.device("cpu"), pin_memory=False)
2022-07-25T16:45:25,365 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _7 = (backbone).forward(pixel_values, pixel_mask, )
2022-07-25T16:45:25,365 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -           ~~~~~~~~~~~~~~~~~ <--- HERE
2022-07-25T16:45:25,366 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _8, _9, _10, = _7
2022-07-25T16:45:25,366 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _11 = torch.flatten((input_projection).forward(_8, ), 2)
2022-07-25T16:45:25,366 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -   File "code/__torch__/transformers/models/detr/modeling_detr.py", line 75, in forward
2022-07-25T16:45:25,366 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _21 = (position_embedding).forward1(_16, )
2022-07-25T16:45:25,367 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _22 = (position_embedding).forward2(_17, )
2022-07-25T16:45:25,367 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _23 = (position_embedding).forward3(_18, )
2022-07-25T16:45:25,367 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
2022-07-25T16:45:25,367 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     return (_19, torch.to(_23, 6), _18)
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - class DetrTimmConvEncoder(Module):
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -   File "code/__torch__/transformers/models/detr/modeling_detr.py", line 311, in forward3
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     x_embed3 = torch.mul(_176, CONSTANTS.c2)
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     tensor1 = torch.arange(128, dtype=6, layout=0, device=torch.device("cpu"), pin_memory=False)
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _177 = torch.div(tensor1, CONSTANTS.c3, rounding_mode="floor")
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -            ~~~~~~~~~ <--- HERE
2022-07-25T16:45:25,368 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _178 = torch.div(torch.mul(_177, CONSTANTS.c3), CONSTANTS.c4)
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG -     _179 = torch.to(CONSTANTS.c5, torch.device("cpu"), 6)
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - 
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - Traceback of TorchScript, original code (most recent call last):
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/pytorch_utils.py(31): torch_int_div
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/models/detr/modeling_detr.py(423): forward
2022-07-25T16:45:25,369 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
2022-07-25T16:45:25,370 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
2022-07-25T16:45:25,370 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/models/detr/modeling_detr.py(378): forward
2022-07-25T16:45:25,370 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
2022-07-25T16:45:25,370 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
2022-07-25T16:45:25,371 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/models/detr/modeling_detr.py(1252): forward
2022-07-25T16:45:25,371 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
2022-07-25T16:45:25,371 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
2022-07-25T16:45:25,371 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/models/detr/modeling_detr.py(1400): forward
2022-07-25T16:45:25,371 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
2022-07-25T16:45:25,372 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
2022-07-25T16:45:25,372 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /cc-repos/ml-research/detr-page-segmentation/scripts/Detr_model.py(81): forward
2022-07-25T16:45:25,372 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
2022-07-25T16:45:25,372 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
2022-07-25T16:45:25,372 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
2022-07-25T16:45:25,373 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace
2022-07-25T16:45:25,373 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /var/folders/pn/ds8zfqbs0_z8t0yd8ysphwjw0000gn/T/ipykernel_46418/1824559375.py(7): <cell line: 1>
2022-07-25T16:45:25,373 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3398): run_code
2022-07-25T16:45:25,373 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3338): run_ast_nodes
2022-07-25T16:45:25,374 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3135): run_cell_async
2022-07-25T16:45:25,374 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
2022-07-25T16:45:25,374 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2936): _run_cell
2022-07-25T16:45:25,375 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2881): run_cell
2022-07-25T16:45:25,375 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/zmqshell.py(528): run_cell
2022-07-25T16:45:25,375 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/ipkernel.py(383): do_execute
2022-07-25T16:45:25,376 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/kernelbase.py(728): execute_request
2022-07-25T16:45:25,376 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/kernelbase.py(404): dispatch_shell
2022-07-25T16:45:25,376 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/kernelbase.py(497): process_one
2022-07-25T16:45:25,376 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/kernelbase.py(508): dispatch_queue
2022-07-25T16:45:25,377 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/asyncio/events.py(81): _run
2022-07-25T16:45:25,377 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/asyncio/base_events.py(1859): _run_once
2022-07-25T16:45:25,377 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/asyncio/base_events.py(570): run_forever
2022-07-25T16:45:25,378 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/tornado/platform/asyncio.py(199): start
2022-07-25T16:45:25,379 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel/kernelapp.py(712): start
2022-07-25T16:45:25,384 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/traitlets/config/application.py(976): launch_instance
2022-07-25T16:45:25,384 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/site-packages/ipykernel_launcher.py(17): <module>
2022-07-25T16:45:25,384 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/runpy.py(87): _run_code
2022-07-25T16:45:25,385 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - /.pyenv/versions/3.8.10/lib/python3.8/runpy.py(194): _run_module_as_main
2022-07-25T16:45:25,385 [INFO ] W-9000-model_1.0-stdout MODEL_LOG - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I found an issue that sounds similar: [PYTORCH] Trace on CPU and use on GPU. There they suggest passing the map_location argument to torch.jit.load to specify the device the model should be loaded onto. I'm doing this but still get the above error.
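For reference, my loading step looks roughly like the sketch below (the file name is a placeholder, and the stand-in artifact is only there to make the snippet self-contained). As far as I can tell, map_location relocates the model's parameters and buffers, but it does not rewrite device constants that torch.jit.trace baked into the graph, which might be why the error persists:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in artifact so the snippet runs on its own; in my case this file
# is the traced DETR archive produced on the cpu instance.
torch.jit.trace(torch.nn.Linear(4, 2).eval(), torch.randn(1, 4)).save("detr_traced.pt")

# map_location moves parameters/buffers to the target device, but constants
# recorded in the traced graph (e.g. torch.ones(..., device="cpu")) stay put.
model = torch.jit.load("detr_traced.pt", map_location=device)
output = model(torch.randn(1, 4, device=device))
```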

Is there something else I need to do to make that work?

P.S.: Separately, I also tried to compile the model for use on aws-inferentia instances and got a divide-by-zero error. This may or may not be related. In case it is relevant, here's a link to the issue I raised: Compilation error for :hugs: Detr model: TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero.

Thanks for the help!