[PYTORCH] Trace on CPU and use on GPU

Hi All,
Is it possible to trace the GPT/BERT models on CPU and use the saved traced model on GPU?

I see a constant called “device” in the traced graph, which seems to persist in the saved model. This makes the model usable only on the device it was traced on, i.e., if it’s traced on a GPU, the saved JIT model is only usable on GPU hosts. I am using PyTorch 1.4 and transformers 3.0.2.

Am I missing something here?

Thanks,

Not sure if I’m right, but I think you can specify the device to load the traced/saved model onto with the map_location parameter of the load function:

https://pytorch.org/docs/master/generated/torch.jit.load.html#torch.jit.load
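When nothing device-specific is baked into the traced graph, map_location is enough. A minimal sketch, assuming a hypothetical toy module (TinyModel is not part of transformers): it is traced on CPU, saved to an in-memory buffer, and reloaded with map_location onto a GPU if one is available, otherwise onto the CPU so the example runs anywhere.

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a real model: only parameters, no tensors created in forward."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))


model = TinyModel().eval()
example = torch.randint(0, 100, (1, 6))

# Trace on CPU and serialize to an in-memory buffer instead of a file.
traced = torch.jit.trace(model, (example,))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# map_location moves the saved parameters and buffers to the target device at load time.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded = torch.jit.load(buffer, map_location=target)

with torch.no_grad():
    out = loaded(example.to(target))
print(out.device, out.shape)
```

This only works because the graph itself does not reference a device; a device constant recorded during tracing is not rewritten by map_location.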

@jlaute: Thanks for the quick response. I might be missing something simple; it would be very helpful to understand this.

The sample code that I am running is attached below:

Transformers version == 3.0.2
Torch version == 1.4.0

from transformers import OpenAIGPTTokenizer, OpenAIGPTModel
import torch

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')

inputs = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
outputs = model(inputs)
print(outputs)

print("To CUDA:")
inputs = inputs.to("cuda")
model = model.to("cuda")
traced_model = torch.jit.trace(model, (inputs,))
torch.jit.save(traced_model, "openai_gpt_cuda.pt")
print(traced_model.graph)
print("\n")
print("Load model onto CPU")
loaded = torch.jit.load("openai_gpt_cuda.pt", map_location=torch.device("cpu"))
inputs = inputs.to("cpu")
print("\n")
print(loaded.graph)
outputs = loaded(inputs)
print(outputs)

Error seen:

Traceback (most recent call last):
  File "gpt.py", line 23, in <module>
    outputs = loaded(inputs)
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select
The above operation failed in interpreter.
Traceback (most recent call last):
Serialized   File "code/__torch__/torch/nn/modules/module/___torch_mangle_147.py", line 35
    position_ids = torch.arange(_20, dtype=4, layout=0, device=torch.device("cuda:0"), pin_memory=False)
    input0 = torch.view(torch.unsqueeze(position_ids, 0), [-1, _19])
    _21 = torch.add((_14).forward(input, ), (_13).forward(input0, ), alpha=1)
                                             ~~~~~~~~~~~~ <--- HERE
    input1 = torch.add(_21, CONSTANTS.c0, alpha=1)
    _22 = (_12).forward(input1, )
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/functional.py(1484): embedding
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/sparse.py(114): forward
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(516): _slow_forward
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(530): __call__
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/transformers/modeling_openai.py(433): forward
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(516): _slow_forward
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(530): __call__
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(1034): trace_module
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(882): trace
gpt.py(14): <module>
Serialized   File "code/__torch__/torch/nn/modules/module/___torch_mangle_0.py", line 8, in forward
  def forward(self: __torch__.torch.nn.modules.module.___torch_mangle_0.Module,
    input: Tensor) -> Tensor:
    position_embeds = torch.embedding(self.weight, input, -1, False, False)
                      ~~~~~~~~~~~~~~~ <--- HERE
    return position_embeds

The above operation failed in interpreter.
Traceback (most recent call last):

Bumping this topic up… Any help would be much appreciated.

Reposting my answer from the GitHub issue here for anyone facing the same kind of issue:

The problem arises in modeling_openai.py when the user does not provide the position_ids argument, so the position_ids tensor is created inside the forward call. This is fine in classic PyTorch, because forward is evaluated at every call. With tracing, however, it is an issue: the device used in forward to create that tensor gets hard-coded, and you can see it in the generated graph:

%input.1 : Tensor = aten::view(%input_ids.1, %64) 
%140 : Device = prim::Constant[value="cuda:0"]() 
%position_ids.1 : Tensor = aten::arange(%59, %67, %45, %140, %70) 
%73 : Tensor = aten::unsqueeze(%position_ids.1, %45)

Above, %140 is a constant whose value is set to "cuda:0"; it is then used to create the %position_ids.1 tensor through aten::arange(..., %140, ...), which leads to the error you’re seeing.
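To see where such a constant comes from, here is a minimal sketch with a hypothetical toy module (PositionsInForward loosely mimics the position-id logic visible in the traceback above; it is not the real modeling_openai.py code). Tracing it and listing the Device constants in its graph shows the tracing device recorded as a literal.

```python
import torch
import torch.nn as nn


class PositionsInForward(nn.Module):
    """Hypothetical toy module: position ids are created inside forward
    when the caller does not pass them."""

    def __init__(self, vocab_size: int = 100, max_len: int = 32, hidden: int = 8):
        super().__init__()
        self.tokens_embed = nn.Embedding(vocab_size, hidden)
        self.positions_embed = nn.Embedding(max_len, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        seq_len = input_ids.size(-1)
        # input_ids.device is a plain Python object at trace time, so the tracer
        # records it as a Device constant rather than a value that follows the module.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
        return self.tokens_embed(input_ids) + self.positions_embed(position_ids)


example = torch.randint(0, 100, (1, 6))
traced = torch.jit.trace(PositionsInForward().eval(), (example,))
graph_text = str(traced.graph)

# Collect the Device constants, analogous to %140 in the excerpt above.
device_constants = [
    line.strip() for line in graph_text.splitlines() if "Device = prim::Constant" in line
]
print(device_constants)
```

Traced on a GPU, the same constant would read "cuda:0", which is what breaks loading on a CPU host. Since the constant only appears when position_ids is built inside forward, passing position_ids explicitly when tracing the real model should also avoid it, because it then becomes a graph input instead.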

I’ll submit a fix that creates position_ids as a buffer registered at Module initialization rather than during forward, so it should be correctly handled by the map_location parameter when loading the model.
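A minimal sketch of the buffer approach, assuming a hypothetical toy module (PositionsAsBuffer, not the actual transformers patch): the positions are registered once with register_buffer, sliced to the input length in forward, and therefore saved with the module and moved by map_location. It traces on CPU and reloads onto a GPU when one is available.

```python
import io

import torch
import torch.nn as nn


class PositionsAsBuffer(nn.Module):
    """Hypothetical toy module: position ids live in a buffer created in __init__."""

    def __init__(self, vocab_size: int = 100, max_len: int = 32, hidden: int = 8):
        super().__init__()
        self.tokens_embed = nn.Embedding(vocab_size, hidden)
        self.positions_embed = nn.Embedding(max_len, hidden)
        # Shape (1, max_len); sliced to the actual sequence length in forward.
        self.register_buffer("position_ids", torch.arange(max_len).unsqueeze(0))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        seq_len = input_ids.size(-1)
        position_ids = self.position_ids[:, :seq_len]
        return self.tokens_embed(input_ids) + self.positions_embed(position_ids)


model = PositionsAsBuffer().eval()
example = torch.randint(0, 100, (1, 6))

# Trace on CPU; no Device constant is needed to build the positions anymore.
traced = torch.jit.trace(model, (example,))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# The saved buffer is moved along with the weights by map_location.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded = torch.jit.load(buffer, map_location=target)
with torch.no_grad():
    out = loaded(example.to(target))
print(out.device, out.shape)
```

A buffer rather than a parameter fits here because position ids are not trained; buffers are still serialized with the module and follow .to() and map_location.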
